Skip to content

Commit

Permalink
fix bug in weighted_survival_score()
Browse files Browse the repository at this point in the history
  • Loading branch information
bblodfon committed Sep 19, 2023
1 parent 87d69b2 commit 7b70fd8
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions R/integrated_scores.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,18 @@ weighted_survival_score = function(loss, truth, distribution, times, t_max, p_ma
unique_times = .c_get_unique_times(truth[, "time"], times)
}

is_distr_or_array =
(inherits(distribution, "Distribution")) ||
(inherits(distribution, "array") & length(dim(array)) == 3)

if (is_distr_or_array) {
# get the cdf
if (inherits(distribution, "Distribution")) {
cdf = as.matrix(distribution$cdf(unique_times))
} else {
}
else if (inherits(distribution, "array") &
length(dim(distribution)) == 3) {
# 'distribution' is a survival 3d array so create an
# `Arrdist` using the 'median' curve
arrdistr = distr6::as.Distribution(1 - distribution, fun = "cdf",
decorators = c("CoreStatistics", "ExoticStatistics"))
cdf = as.matrix(arrdistr$cdf(unique_times))
} else { # 'distribution' is a survival 2d matrix
mtc = findInterval(unique_times, as.numeric(colnames(distribution)))
cdf = 1 - t(distribution[, mtc])
if (any(mtc == 0)) {
Expand All @@ -41,7 +46,7 @@ weighted_survival_score = function(loss, truth, distribution, times, t_max, p_ma
rownames(cdf) = unique_times
}

true_times <- truth[, "time"]
true_times = truth[, "time"]

assert_numeric(true_times, any.missing = FALSE)
assert_numeric(unique_times, any.missing = FALSE)
Expand Down

0 comments on commit 7b70fd8

Please sign in to comment.