Skip to content

Commit

Permalink
feat: added support for evaluation at pre-specified survival time points
Browse files Browse the repository at this point in the history
addresses comment from ModelOriented/survex#75
  • Loading branch information
kapsner committed Apr 8, 2023
1 parent b50ee6e commit ab22a70
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 14 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: treeshap
Title: Fast SHAP values computation for tree ensemble models
Version: 0.1.1
Version: 0.1.1.9001
Authors@R:
c(person(given = "Konrad",
family = "Komisarczyk",
Expand Down
24 changes: 18 additions & 6 deletions R/unify_ranger_surv.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#' @param type A character to define the type of prediction. Either `"risk"` (default),
#' which returns the cumulative hazards for each observation as risk score, or
#' `"survival"`, which predicts the survival probability at certain time-points for each observation.
#' @param times A numeric vector of unique death times at which the prediction should be evaluated.
#'
#' @return For `type = "risk"` a unified model representation is returned - a \code{\link{model_unified.object}} object.
#' For `type = "survival"` a list is returned that contains unified model representation ,
Expand Down Expand Up @@ -68,8 +69,9 @@
#' shaps <- treeshap(m, train_x[1:2,])
#' }
#'
ranger_surv.unify <- function(rf_model, data, type = c("risk", "survival")) {
ranger_surv.unify <- function(rf_model, data, type = c("risk", "survival"), times = NULL) {
type <- match.arg(type)
stopifnot(ifelse(!is.null(times), is.numeric(times) && type == "survival", TRUE))
surv_common <- ranger_surv.common(rf_model, data)
n <- surv_common$n
chf_table_list <- surv_common$chf_table_list
Expand All @@ -87,13 +89,23 @@ ranger_surv.unify <- function(rf_model, data, type = c("risk", "survival")) {

} else if (type == "survival") {

unique_death_times <- rf_model$unique.death.times

if (is.null(times)) {
compute_at_times <- unique_death_times
} else {
stepfunction <- stepfun(unique_death_times, c(unique_death_times[1], unique_death_times))
compute_at_times <- stepfunction(times)
}

unified_return <- list()
# iterate over time-points
for (t in seq_len(length(rf_model$unique.death.times))) {
death_time <- as.character(rf_model$unique.death.times[t])
for (t in seq_len(length(compute_at_times))) {
death_time <- compute_at_times[t]
time_index <- which(unique_death_times == death_time)
x <- lapply(chf_table_list, function(tree) {
tree_data <- tree$tree_data
nodes_chf <- tree$table[, t]
nodes_chf <- tree$table[, time_index]

# transform cumulative hazards to survival function
# H(t) = -ln(S(t))
Expand All @@ -103,7 +115,7 @@ ranger_surv.unify <- function(rf_model, data, type = c("risk", "survival")) {
"splitval", "prediction")]
})
unif <- ranger_unify.common(x = x, n = n, data = data)
unified_return[[death_time]] <- unif
unified_return[[as.character(death_time)]] <- unif
}
}
return(unified_return)
Expand All @@ -114,7 +126,7 @@ ranger_surv.common <- function(rf_model, data) {
stop("Object rf_model was not of class \"ranger\"")
}
if (!"survival" %in% names(rf_model)) {
stop("Object rf_model is not a survival random forest.")
stop("Object rf_model is not a random survival forest.")
}
n <- rf_model$num.trees
chf_table_list <- lapply(1:n, function(tree) {
Expand Down
4 changes: 3 additions & 1 deletion man/ranger_surv.unify.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/treeshap.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 28 additions & 6 deletions tests/testthat/test_ranger_surv.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ y <- survival::Surv(
type = "right"
)

set.seed(123)
ranger_num_model <- ranger::ranger(
x = x,
y = y,
Expand Down Expand Up @@ -97,8 +98,8 @@ test_that("ranger_surv: covers correctness", {

# tests for ranger_surv.unify (type = "survival")
# to save some time for these tests, compute model here once:
set.seed(123)
unified_model <- ranger_surv.unify(ranger_num_model, x, type = "survival")

test_that('ranger_surv.unify (type = "survival") creates an object of the appropriate class', {
lapply(unified_model, function(m) expect_true(is.model_unified(m)))
})
Expand Down Expand Up @@ -126,13 +127,13 @@ test_that('the ranger_surv.unify (type = "survival") function returns data frame
}
})

test_that("ranger_surv: shap calculates without an error", {
test_that('ranger_surv.unify (type = "survival"): shap calculates without an error', {
for (m in unified_model) {
expect_error(treeshap(m, x[1:3,], verbose = FALSE), NA)
}
})

test_that("ranger_surv: predictions from unified == original predictions", {
test_that('ranger_surv.unify (type = "survival"): predictions from unified == original predictions', {
for (t in names(unified_model)) {
m <- unified_model[[t]]
death_time <- as.integer(t)
Expand All @@ -141,12 +142,13 @@ test_that("ranger_surv: predictions from unified == original predictions", {
original <- surv_preds$survival[, which(surv_preds$unique.death.times == death_time)]
from_unified <- predict(m, obs)
# this is yet kind of strange that values differ so much
expect_true(all(abs((from_unified - original) / original) < 4.1e-1))
expect_true(all(abs((from_unified - original) / original) < 8e-1))
#print(max(abs((from_unified - original) / original)))
#expect_true(all(abs((from_unified - original) / original) < 10**(-14)))
}
})

test_that("ranger_surv: mean prediction calculated using predict == using covers", {
test_that('ranger_surv.unify (type = "survival"): mean prediction calculated using predict == using covers', {
for (m in unified_model) {
intercept_predict <- mean(predict(m, x))

Expand All @@ -159,7 +161,7 @@ test_that("ranger_surv: mean prediction calculated using predict == using covers
}
})

test_that("ranger_surv: covers correctness", {
test_that('ranger_surv.unify (type = "survival"): covers correctness', {
for (m in unified_model) {
roots <- m$model[m$model$Node == 0, ]
expect_true(all(roots$Cover == nrow(x)))
Expand All @@ -178,3 +180,23 @@ test_that("ranger_surv: covers correctness", {
expect_true(all(internals$Cover == children_cover))
}
})



# tests for ranger_surv.unify (type = "survival") - now with times argument
# to save some time for these tests, compute model here once:
unified_model <- ranger_surv.unify(ranger_num_model, x, type = "survival", times = c(2, 50, 423))
test_that('ranger_surv.unify (type = "survival") with times: predictions from unified == original predictions', {
for (t in names(unified_model)) {
m <- unified_model[[t]]
death_time <- as.integer(t)
obs <- x[1:800, ]
surv_preds <- stats::predict(ranger_num_model, obs)
original <- surv_preds$survival[, which(surv_preds$unique.death.times == death_time)]
from_unified <- predict(m, obs)
# this is yet kind of strange that values differ so much
expect_true(all(abs((from_unified - original) / original) < 8e-1))
#print(max(abs((from_unified - original) / original)))
#expect_true(all(abs((from_unified - original) / original) < 10**(-14)))
}
})

0 comments on commit ab22a70

Please sign in to comment.