Skip to content

Commit

Permalink
fix demo examples (#307)
Browse files Browse the repository at this point in the history
  • Loading branch information
gowerc authored Apr 12, 2024
1 parent a273133 commit 8e45d0b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 16 deletions.
20 changes: 12 additions & 8 deletions design/examples/loglogistic.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@ library(posterior)
library(bayesplot)
library(here)

# library(jmpost)
devtools::load_all()


dat <- flexsurv::bc |>
as_tibble() |>
mutate(arm = "A", study = "S", pt = sprintf("pt-%05d", 1:n()))
mutate(arm = "A", study = "S", pt = sprintf("pt-%05d", seq_len(n())))


#
Expand Down Expand Up @@ -101,8 +104,7 @@ k <- 2
# JMpost
#

devtools::load_all()
# library(jmpost)


jm <- JointModel(
survival = SurvivalLogLogistic()
Expand Down Expand Up @@ -131,20 +133,22 @@ mp <- sampleStanModel(
)

vars <- c(
"sm_logl_lambda",
"sm_logl_p"
"sm_loglogis_a",
"sm_loglogis_b"
)

x <- mp@results$summary(vars)
stanobj <- as.CmdStanMCMC(mp)

x <- stanobj$summary(vars)

c(
"scale" = 1 / x$mean[1],
"scale" = x$mean[1],
"shape" = x$mean[2]
)


# Log Likelihood
log_lik <- mp@results$draws("log_lik", format = "draws_matrix") |>
log_lik <- stanobj$draws("log_lik", format = "draws_matrix") |>
apply(1, sum) |>
mean()
log_lik
Expand Down
22 changes: 14 additions & 8 deletions design/examples/weibull.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ devtools::load_all()

dat <- flexsurv::bc |>
as_tibble() |>
mutate(arm = "A", study = "S", pt = sprintf("pt-%05d", 1:n()))
mutate(arm = "A", study = "S", pt = sprintf("pt-%05d", seq_len(n())))


#
Expand Down Expand Up @@ -300,10 +300,11 @@ vars <- c(
"beta_os_cov"
)

mp@results$summary(vars)
stanobj <- as.CmdStanMCMC(mp)
stanobj$summary(vars)

# Log Likelihood
log_lik <- mp@results$draws("log_lik", format = "draws_matrix") |>
log_lik <- stanobj$draws("log_lik", format = "draws_matrix") |>
apply(1, sum) |>
mean()
log_lik
Expand All @@ -316,19 +317,22 @@ k <- 2
(4 * log(nrow(dat))) + (-2 * log_lik)

# Leave one out CV
mp@results$loo()
stanobj$loo()


#### Extract Desired Quantities

prediction_times <- seq(min(dat$recyrs), max(dat$recyrs), length.out = 20)
selected_patients <- c("pt-00681", "pt-00002")


# Survival plots
sq_surv <- SurvivalQuantities(
mp,
time_grid = prediction_times,
groups = selected_patients,
grid = GridFixed(
subjects = selected_patients,
times = prediction_times
),
type = "surv"
)
autoplot(sq_surv, add_km = FALSE, add_wrap = FALSE)
Expand All @@ -338,8 +342,10 @@ summary(sq_surv)
# Hazard
sq_haz <- SurvivalQuantities(
mp,
time_grid = prediction_times,
groups = selected_patients,
grid = GridFixed(
subjects = selected_patients,
times = prediction_times
),
type = "haz"
)
autoplot(sq_haz, add_km = FALSE, add_wrap = FALSE)

0 comments on commit 8e45d0b

Please sign in to comment.