Skip to content

Commit

Permalink
Merge pull request #74 from kapsner/feat-global-survshap
Browse files Browse the repository at this point in the history
Feature: aggregation of SurvSHAP values across multiple observation
  • Loading branch information
mikolajsp authored May 22, 2023
2 parents cff0131 + 3bdc619 commit 2843273
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 13 deletions.
7 changes: 4 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: survex
Title: Explainable Machine Learning in Survival Analysis
Version: 1.0.0.9000
Version: 1.0.0.9001
Authors@R:
c(
person("Mikołaj", "Spytek", email = "mikolajspytek@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-7111-2286")),
Expand All @@ -18,15 +18,16 @@ Description: Survival analysis models are commonly used in medicine and other ar
License: GPL (>= 3)
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.1
RoxygenNote: 7.2.3
Depends: R (>= 3.5.0)
Imports:
DALEX (>= 2.2.1),
ggplot2,
kernelshap,
pec,
survival,
patchwork
patchwork,
data.table
Suggests:
censored,
covr,
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ export(survival_to_cumulative_hazard)
export(theme_default_survex)
export(theme_vertical_default_survex)
export(transform_to_stepfunction)
import(data.table)
import(ggplot2)
import(patchwork)
import(survival)
Expand Down
89 changes: 79 additions & 10 deletions R/surv_shap.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,32 @@ surv_shap <- function(explainer,
B = 25,
exact = FALSE
) {
# make this code work for multiple observations
stopifnot(ifelse(!is.null(y_true),
ifelse(is.matrix(y_true),
nrow(new_observation) == nrow(y_true),
is.null(dim(y_true)) && length(y_true) == 2L),
TRUE))

test_explainer(explainer, "surv_shap", has_data = TRUE, has_y = TRUE, has_survival = TRUE)
new_observation <- new_observation[, colnames(new_observation) %in% colnames(explainer$data)]

# make this code also work for 1-row matrix
col_index <- which(colnames(new_observation) %in% colnames(explainer$data))
if (is.matrix(new_observation) && nrow(new_observation) == 1) {
new_observation <- as.matrix(t(new_observation[, col_index]))
} else {
new_observation <- new_observation[, col_index]
}

if (ncol(explainer$data) != ncol(new_observation)) stop("New observation and data have different number of columns (variables)")

if (!is.null(y_true)) {
if (is.matrix(y_true)) {
y_true_ind <- y_true[1, 2]
y_true_time <- y_true[1, 1]
# above, we have already checked that nrows of observations are
# identical to nrows of y_true; thus we do not need to index
# the first row here
y_true_ind <- y_true[, 2]
y_true_time <- y_true[, 1]
} else {
y_true_ind <- y_true[2]
y_true_time <- y_true[1]
Expand All @@ -40,7 +58,8 @@ surv_shap <- function(explainer,

res <- list()
res$eval_times <- explainer$times
res$variable_values <- new_observation
# to display final object correctly, when is.matrix(new_observation) == TRUE
res$variable_values <- as.data.frame(new_observation)

res$result <- switch(calculation_method,
"exact_kernel" = shap_kernel(explainer, new_observation, ...),
Expand Down Expand Up @@ -148,14 +167,64 @@ aggregate_surv_shap <- function(survshap, method) {
use_kernelshap <- function(explainer, new_observation, ...){

predfun <- function(model, newdata){
explainer$predict_survival_function(model, newdata, times=explainer$times)
explainer$predict_survival_function(
model,
newdata,
times = explainer$times
)
}

tmp_res <- kernelshap::kernelshap(explainer$model, new_observation, bg_X = explainer$data,
pred_fun = predfun, verbose=FALSE)
tmp_res_list <- sapply(
X = as.character(seq_len(nrow(new_observation))),
FUN = function(i) {
tmp_res <- kernelshap::kernelshap(
object = explainer$model,
X = new_observation[as.integer(i), ],
bg_X = explainer$data,
pred_fun = predfun,
verbose = FALSE
)
tmp_shap_values <- data.frame(t(sapply(tmp_res$S, cbind)))
colnames(tmp_shap_values) <- colnames(tmp_res$X)
rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "")
data.table::as.data.table(tmp_shap_values, keep.rownames = TRUE)
},
USE.NAMES = TRUE,
simplify = FALSE
)

shap_values <- aggregate_shap_multiple_observations(
shap_res_list = tmp_res_list,
feature_names = colnames(new_observation)
)

return(shap_values)
}


aggregate_shap_multiple_observations <- function(shap_res_list, feature_names) {

shap_values <- data.frame(t(sapply(tmp_res$S, cbind)))
colnames(shap_values) <- colnames(tmp_res$X)
rownames(shap_values) <- paste("t=", explainer$times, sep = "")
if (length(shap_res_list) > 1) {

full_survshap_results <- data.table::rbindlist(
l = shap_res_list,
use.names = TRUE,
idcol = TRUE
)

# compute arithmetic mean for each time-point and feature across
# multiple observations
tmp_res <- full_survshap_results[
, lapply(.SD, mean), by = "rn", .SDcols = feature_names
]
} else {
# no aggregation required
tmp_res <- shap_res_list[[1]]
}
shap_values <- tmp_res[, .SD, .SDcols = setdiff(colnames(tmp_res), "rn")]
# transform to data.frame to make everything compatible with
# previous code
shap_values <- data.frame(shap_values)
rownames(shap_values) <- tmp_res$rn
return(shap_values)
}
2 changes: 2 additions & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#' @import data.table
NULL
21 changes: 21 additions & 0 deletions tests/testthat/test-predict_parts.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,27 @@ test_that("survshap explanations work", {

})

test_that("global survshap explanations with kernelshap work for ranger", {
veteran <- survival::veteran

rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5)
rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE)

parts_ranger <- predict_parts(
rsf_ranger_exp,
veteran[1:40, !colnames(veteran) %in% c("time", "status")],
y_true = Surv(veteran$time[1:40], veteran$status[1:40]),
aggregation_method = "mean_absolute",
calculation_method = "kernelshap"
)
plot(parts_ranger)

expect_s3_class(parts_ranger, c("predict_parts_survival", "surv_shap"))
expect_equal(nrow(parts_ranger$result), length(rsf_ranger_exp$times))
expect_true(all(colnames(parts_ranger$result) == colnames(rsf_ranger_exp$data)))

})


test_that("survlime explanations work", {

Expand Down

0 comments on commit 2843273

Please sign in to comment.