From 00d6a65926df1df50f3325c3d4885bb8031ba83f Mon Sep 17 00:00:00 2001 From: kapsner Date: Wed, 5 Apr 2023 21:28:40 +0200 Subject: [PATCH 1/2] feat: aggregate survshap across multiple observations --- DESCRIPTION | 7 ++- NAMESPACE | 1 + R/surv_shap.R | 89 +++++++++++++++++++++++++---- R/zzz.R | 2 + tests/testthat/test-predict_parts.R | 20 +++++++ 5 files changed, 106 insertions(+), 13 deletions(-) create mode 100644 R/zzz.R diff --git a/DESCRIPTION b/DESCRIPTION index 93e5581c..5fc010cc 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: survex Title: Explainable Machine Learning in Survival Analysis -Version: 1.0.0.9000 +Version: 1.0.0.9001 Authors@R: c( person("Mikołaj", "Spytek", email = "mikolajspytek@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-7111-2286")), @@ -18,7 +18,7 @@ Description: Survival analysis models are commonly used in medicine and other ar License: GPL (>= 3) Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.1 +RoxygenNote: 7.2.3 Depends: R (>= 3.5.0) Imports: DALEX (>= 2.2.1), @@ -26,7 +26,8 @@ Imports: kernelshap, pec, survival, - patchwork + patchwork, + data.table Suggests: censored, covr, diff --git a/NAMESPACE b/NAMESPACE index d055c56b..1dc5c536 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -68,6 +68,7 @@ export(survival_to_cumulative_hazard) export(theme_default_survex) export(theme_vertical_default_survex) export(transform_to_stepfunction) +import(data.table) import(ggplot2) import(patchwork) import(survival) diff --git a/R/surv_shap.R b/R/surv_shap.R index bbb36b36..18775705 100644 --- a/R/surv_shap.R +++ b/R/surv_shap.R @@ -24,14 +24,32 @@ surv_shap <- function(explainer, B = 25, exact = FALSE ) { + # make this code work for multiple observations + stopifnot(ifelse(!is.null(y_true), + ifelse(is.matrix(y_true), + nrow(new_observation) == nrow(y_true), + is.null(dim(y_true)) && length(y_true) == 2L), + TRUE)) + test_explainer(explainer, "surv_shap", has_data = TRUE, has_y = TRUE,
has_survival = TRUE) - new_observation <- new_observation[, colnames(new_observation) %in% colnames(explainer$data)] + + # make this code also work for 1-row matrix + col_index <- which(colnames(new_observation) %in% colnames(explainer$data)) + if (is.matrix(new_observation) && nrow(new_observation) == 1) { + new_observation <- as.matrix(t(new_observation[, col_index])) + } else { + new_observation <- new_observation[, col_index] + } + if (ncol(explainer$data) != ncol(new_observation)) stop("New observation and data have different number of columns (variables)") if (!is.null(y_true)) { if (is.matrix(y_true)) { - y_true_ind <- y_true[1, 2] - y_true_time <- y_true[1, 1] + # above, we have already checked that nrows of observations are + # identical to nrows of y_true; thus we do not need to index + # the first row here + y_true_ind <- y_true[, 2] + y_true_time <- y_true[, 1] } else { y_true_ind <- y_true[2] y_true_time <- y_true[1] @@ -40,7 +58,8 @@ surv_shap <- function(explainer, res <- list() res$eval_times <- explainer$times - res$variable_values <- new_observation + # to display final object correctly, when is.matrix(new_observation) == TRUE + res$variable_values <- as.data.frame(new_observation) res$result <- switch(calculation_method, "exact_kernel" = shap_kernel(explainer, new_observation, ...), @@ -148,14 +167,64 @@ aggregate_surv_shap <- function(survshap, method) { use_kernelshap <- function(explainer, new_observation, ...){ predfun <- function(model, newdata){ - explainer$predict_survival_function(model, newdata, times=explainer$times) + explainer$predict_survival_function( + model, + newdata, + times = explainer$times + ) } - tmp_res <- kernelshap::kernelshap(explainer$model, new_observation, bg_X = explainer$data, - pred_fun = predfun, verbose=FALSE) + tmp_res_list <- sapply( + X = as.character(seq_len(nrow(new_observation))), + FUN = function(i) { + tmp_res <- kernelshap::kernelshap( + object = explainer$model, + X = new_observation[as.integer(i), ], + 
bg_X = explainer$data, + pred_fun = predfun, + verbose = FALSE + ) + tmp_shap_values <- data.frame(t(sapply(tmp_res$S, cbind))) + colnames(tmp_shap_values) <- colnames(tmp_res$X) + rownames(tmp_shap_values) <- paste("t=", explainer$times, sep = "") + data.table::as.data.table(tmp_shap_values, keep.rownames = TRUE) + }, + USE.NAMES = TRUE, + simplify = FALSE + ) + + shap_values <- aggregate_shap_multiple_observations( + shap_res_list = tmp_res_list, + feature_names = colnames(new_observation) + ) + + return(shap_values) +} + + +aggregate_shap_multiple_observations <- function(shap_res_list, feature_names) { - shap_values <- data.frame(t(sapply(tmp_res$S, cbind))) - colnames(shap_values) <- colnames(tmp_res$X) - rownames(shap_values) <- paste("t=", explainer$times, sep = "") + if (length(shap_res_list) > 1) { + + full_survshap_results <- data.table::rbindlist( + l = shap_res_list, + use.names = TRUE, + idcol = TRUE + ) + + # compute arithmetic mean for each time-point and feature across + # multiple observations + tmp_res <- full_survshap_results[ + , lapply(.SD, mean), by = "rn", .SDcols = feature_names + ] + } else { + # no aggregation required + tmp_res <- shap_res_list[[1]] + } + shap_values <- tmp_res[, .SD, .SDcols = setdiff(colnames(tmp_res), "rn")] + # transform to data.frame to make everything compatible with + # previous code + shap_values <- data.frame(shap_values) + rownames(shap_values) <- tmp_res$rn return(shap_values) } diff --git a/R/zzz.R b/R/zzz.R new file mode 100644 index 00000000..4db9ceea --- /dev/null +++ b/R/zzz.R @@ -0,0 +1,2 @@ +#' @import data.table +NULL diff --git a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index 20539a9f..43321e05 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -48,6 +48,26 @@ test_that("survshap explanations work", { }) +test_that("global survshap explanations with kernelshap work for ranger", { + veteran <- survival::veteran + + rsf_ranger <- 
ranger::ranger(survival::Surv(time, status) ~ ., data = veteran, respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5) + rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status), verbose = FALSE) + + parts_ranger <- predict_parts( + rsf_ranger_exp, + veteran[1:40, !colnames(veteran) %in% c("time", "status")], + y_true = Surv(veteran$time[1:40], veteran$status[1:40]), + aggregation_method = "mean_absolute", + calculation_method = "kernelshap" + ) + + expect_s3_class(parts_ranger, c("predict_parts_survival", "surv_shap")) + expect_equal(nrow(parts_ranger$result), length(rsf_ranger_exp$times)) + expect_true(all(colnames(parts_ranger$result) == colnames(rsf_ranger_exp$data))) + +}) + test_that("survlime explanations work", { From 3bdc61976b3b133c3e286c069b39c3b06d3a6625 Mon Sep 17 00:00:00 2001 From: kapsner Date: Thu, 6 Apr 2023 10:54:00 +0200 Subject: [PATCH 2/2] test: added plot for global survshap to unit-test --- tests/testthat/test-predict_parts.R | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/testthat/test-predict_parts.R b/tests/testthat/test-predict_parts.R index 43321e05..15e795f3 100644 --- a/tests/testthat/test-predict_parts.R +++ b/tests/testthat/test-predict_parts.R @@ -61,6 +61,7 @@ test_that("global survshap explanations with kernelshap work for ranger", { aggregation_method = "mean_absolute", calculation_method = "kernelshap" ) + plot(parts_ranger) expect_s3_class(parts_ranger, c("predict_parts_survival", "surv_shap")) expect_equal(nrow(parts_ranger$result), length(rsf_ranger_exp$times))