Skip to content

Commit

Permalink
Merge branch 'i63-i73-optimization' of https://github.com/epiverse-tr…
Browse files Browse the repository at this point in the history
…ace/serofoi into i63-i73-optimization
  • Loading branch information
rccreswell committed May 3, 2024
2 parents e277ecd + e2088e5 commit 13aaf57
Show file tree
Hide file tree
Showing 38 changed files with 558 additions and 269 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: serofoi
Type: Package
Title: Estimates the Force-of-Infection of a given pathogen from population based seroprevalence studies
Version: 0.0.9
Version: 0.1.0
Authors@R:
c(
person(
Expand Down Expand Up @@ -36,7 +36,7 @@ License: MIT + file LICENSE
Encoding: UTF-8
Language: en-GB
LazyData: true
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
NeedsCompilation: yes
Depends:
R (>= 3.5.0)
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export(extract_seromodel_summary)
export(fit_seromodel)
export(fit_seromodel_optimization)
export(generate_sim_data)
export(get_chunk_structure)
export(get_cohort_ages)
export(get_foi_central_estimates)
export(get_prev_expanded)
Expand Down
2 changes: 1 addition & 1 deletion R/model_comparison.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#' @examples
#' data(chagas2012)
#' serodata <- prepare_serodata(serodata = chagas2012)
#' model_constant <- run_seromodel(
#' model_constant <- fit_seromodel(
#' serodata = serodata,
#' foi_model = "constant",
#' iter = 1500
Expand Down
169 changes: 123 additions & 46 deletions R/modelling.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ stop_if_wrong_type <- function(serodata, col_types) {
}
}

warn_missing <- function(serodata, optional_cols) {
warn_missing <- function(
serodata,
optional_cols
) {
if (
!all(
optional_cols
Expand All @@ -52,14 +55,16 @@ warn_missing <- function(serodata, optional_cols) {
) {
missing <- optional_cols[which(!(optional_cols %in% colnames(serodata)))]
warning(
"The following optional columns in `serodata` are missing.",
"The following optional columns in `serodata` are missing. ",
"Consider including them to get more information from this analysis:\n",
toString(missing)
)
for (col in missing) {
serodata[[col]] <- "None" # TODO Shouln't we use `NA` instead?
}
}

return(serodata)
}


Expand All @@ -85,12 +90,16 @@ validate_serodata <- function(serodata) {
antibody = c("character", "factor")
)

warn_missing(serodata,
# Add missing columns
serodata <- warn_missing(
serodata,
optional_cols = names(optional_col_types)
)

# If any optional column is present, validate that it has the correct type
stop_if_wrong_type(serodata, optional_col_types)

return(serodata)
}

validate_prepared_serodata <- function(serodata) {
Expand All @@ -104,15 +113,19 @@ validate_prepared_serodata <- function(serodata) {
prev_obs_lower = "numeric",
prev_obs_upper = "numeric"
)
validate_serodata(serodata)
stop_if_missing(serodata, must_have_cols = names(col_types))
serodata <- validate_serodata(serodata)

stop_if_missing(serodata, must_have_cols = names(col_types))
stop_if_wrong_type(serodata, col_types)

return(serodata)
}

#' Run specified stan model for the force-of-infection and
#' estimate the seroprevalence based on the result of the fit
#'
#' Starting on v.0.1.0, this function will be DEPRECATED. Use `fit_seromodel`
#' instead.
#' This function runs the specified model for the force-of-infection `foi_model`
#' using the data from a seroprevalence survey `serodata` as the input data. See
#' [fit_seromodel] for further details.
Expand All @@ -124,24 +137,29 @@ validate_prepared_serodata <- function(serodata) {
#' the implementation of the model. For further details refer to
#' [fit_seromodel].
#' @examples
#' \dontrun{
#' data(chagas2012)
#' serodata <- prepare_serodata(chagas2012)
#' run_seromodel(
#' serodata,
#' foi_model = "constant"
#' )
#' }
#' @export
run_seromodel <- function(
serodata,
foi_model = c("constant", "tv_normal_log", "tv_normal"),
foi_location = 0,
foi_scale = 1,
chunk_size = 1,
chunks = NULL,
iter = 1000,
thin = 2,
adapt_delta = 0.90,
max_treedepth = 10,
chains = 4,
seed = 12345,
print_summary = TRUE,
...) {
.Deprecated("fit_seromodel")
foi_model <- match.arg(foi_model)
survey <- unique(serodata$survey)
if (length(survey) > 1) {
Expand All @@ -150,11 +168,13 @@ run_seromodel <- function(
seromodel_object <- fit_seromodel(
serodata = serodata,
foi_model = foi_model,
foi_location = foi_location,
foi_scale = foi_scale,
chunk_size = chunk_size,
chunks = chunks,
iter = iter,
thin = thin,
adapt_delta = adapt_delta,
max_treedepth = max_treedepth,
chains = chains,
seed = seed,
...
)
Expand Down Expand Up @@ -200,10 +220,23 @@ run_seromodel <- function(
#' \item{`"tv_normal"`}{Runs a normal model}
#' \item{`"tv_normal_log"`}{Runs a normal logarithmic model}
#' }
#' @param foi_location Location parameter of the force-of-infection distribution
#' of the selected model. Depending on `foi_model`, the meaning may vary.
#' @param foi_scale Scale parameter of the force-of-infection distribution
#' of the selected model. Depending on `foi_model`, the meaning may vary.
#' @param chunks Numeric list specifying the chunk structure of the time
#' interval from the birth year of the oldest age cohort
#' `min(serodata$age_mean_f)` to the time when the serosurvey was conducted
#' `t_sur`. If `NULL`, the time interval is divided in chunks of size
#' `chunk_size`.
#' @param chunk_size Size of the chunks to be used in case that the chunk
#' structure `chunks` is not specified in [fit_seromodel].
#' Default is 1, meaning that one force of infection value is to be estimated
#' for every year in the time interval spanned by the serosurvey.
#' If the length of the time interval is not exactly divisible by `chunk_size`,
#' the remainder years are included in the last chunk.
#' @param iter Number of iterations for each chain including the warmup.
#' `iter` in [sampling][rstan::sampling].
#' @param thin Positive integer specifying the period for saving samples.
#' `thin` in [sampling][rstan::sampling].
#' @param adapt_delta Real number between 0 and 1 that represents the target
#' average acceptance probability. Increasing the value of `adapt_delta` will
#' result in a smaller step size and fewer divergences. For further details
Expand All @@ -212,8 +245,6 @@ run_seromodel <- function(
#' @param max_treedepth Maximum tree depth for the binary tree used in the NUTS
#' stan sampler. For further details refer to the `control` parameter in
#' [sampling][rstan::sampling].
#' @param chains Number of Markov chains for sampling. For further details refer
#' to the `chains` parameter in [sampling][rstan::sampling].
#' @param seed For further details refer to the `seed` parameter in
#' [sampling][rstan::sampling].
#' @param init_value optional list, each element of which is a named list
Expand All @@ -233,74 +264,77 @@ run_seromodel <- function(
fit_seromodel <- function(
serodata,
foi_model = c("constant", "tv_normal_log", "tv_normal"),
foi_location = 0,
foi_scale = 1,
chunks = NULL,
chunk_size = 1,
iter = 1000,
thin = 2,
adapt_delta = 0.90,
max_treedepth = 10,
chains = 4,
seed = 12345,
init_value = NULL,
...) {
# TODO Add a warning because there are exceptions where a minimal amount of
# iterations is needed
# Validate arguments
validate_prepared_serodata(serodata)
serodata <- validate_prepared_serodata(serodata)
stopifnot(
"foi_model must be either `constant`, `tv_normal_log`, or `tv_normal`" =
foi_model %in% c("constant", "tv_normal_log", "tv_normal"),
"iter must be numeric" = is.numeric(iter),
"thin must be numeric" = is.numeric(thin),
"adapt_delta must be numeric" = is.numeric(adapt_delta),
"max_treedepth must be numeric" = is.numeric(max_treedepth),
"chains must be numeric" = is.numeric(chains),
"seed must be numeric" = is.numeric(seed)
)
model <- stanmodels[[foi_model]]
cohort_ages <- get_cohort_ages(serodata = serodata)
exposure_matrix <- get_exposure_matrix(serodata)
n_obs <- nrow(serodata)

if (is.null(chunks)) {
chunks <- get_chunk_structure(
serodata = serodata,
chunk_size = chunk_size
)
}
checkmate::assert_class(chunks, "numeric")
stopifnot(
"`chunks` length must be equal to `max(serodata$age_mean_f)`" =
length(chunks) == max(serodata$age_mean_f)
)

stan_data <- list(
n_obs = n_obs,
n_pos = serodata$counts,
n_total = serodata$total,
age_max = max(cohort_ages$age),
observation_exposure_matrix = exposure_matrix
age_max = max(serodata$age_mean_f),
observation_exposure_matrix = exposure_matrix,
chunks = chunks,
foi_location = foi_location,
foi_scale = foi_scale
)

warmup <- floor(iter / 2)
if (is.null(init_value)) {
if (foi_model == "tv_normal_log") {
f_init <- function() {
list(log_foi = rep(-3, nrow(cohort_ages)))
}
} else {
f_init <- function() {
list(foi = rep(0.01, nrow(cohort_ages)))
}
if (foi_model == "tv_normal_log") {
f_init <- function() {
list(log_fois = rep(-3, max(chunks)))
}
} else {
f_init <- init_value
f_init <- function() {
list(fois = rep(0.01, max(chunks)))
}
}

seromodel_fit <- rstan::sampling(
model,
data = stan_data,
iter = iter,
init = f_init,
warmup = warmup,
control = list(
adapt_delta = adapt_delta,
max_treedepth = max_treedepth
),
thin = thin,
chains = chains,
seed = seed,
# https://github.com/stan-dev/rstan/issues/761#issuecomment-647029649
chain_id = 0,
include = FALSE,
pars = "fois_vector",
verbose = FALSE,
refresh = 0,
...
...
)

if (seromodel_fit@mode == 0) {
Expand All @@ -314,6 +348,50 @@ fit_seromodel <- function(
}


#' Generate the chunk structure to be used in the retrospective estimation of
#' the force of infection.
#'
#' This function generates a numeric vector specifying the chunk structure of
#' the time interval spanning from the year of birth of the oldest age cohort
#' up to the time when the serosurvey was conducted. The vector has one entry
#' per year (`max(serodata$age_mean_f)` entries in total); each entry is the
#' index of the chunk that year belongs to. If the length of the interval is
#' not exactly divisible by `chunk_size`, the remainder years are assigned to
#' the last chunk.
#' @inheritParams fit_seromodel
#' @return Numeric vector of length `max(serodata$age_mean_f)` containing the
#'   chunk index (starting at 1) assigned to each year.
#' @examples
#' data(chagas2012)
#' serodata <- prepare_serodata(serodata = chagas2012, alpha = 0.05)
#' chunks <- get_chunk_structure(serodata = serodata, chunk_size = 5)
#' @export
get_chunk_structure <- function(
    serodata,
    chunk_size = 1
) {
  age_max <- max(serodata$age_mean_f)

  checkmate::assert_int(
    chunk_size,
    lower = 1,
    upper = age_max
  )

  # One index per complete chunk, each repeated `chunk_size` times.
  # `seq()` with a double `by` yields a double vector; keep it that way, as
  # `fit_seromodel` asserts that `chunks` has class "numeric" (an integer
  # vector would not pass that check).
  chunks <- rep(
    seq(1, age_max / chunk_size, 1),
    each = chunk_size
  )

  # Years left over after the last complete chunk are merged into it
  chunks <- c(
    chunks,
    rep(max(chunks), age_max - length(chunks))
  )

  return(chunks)
}

#' Generate data frame containing the age of each cohort
#' corresponding to each birth year excluding the year of the survey.
#'
Expand Down Expand Up @@ -439,7 +517,7 @@ get_foi_central_estimates <- function(seromodel_object,
#' @examples
#' data(chagas2012)
#' serodata <- prepare_serodata(chagas2012)
#' seromodel_object <- run_seromodel(
#' seromodel_object <- fit_seromodel(
#' serodata = serodata,
#' foi_model = "constant"
#' )
Expand All @@ -449,6 +527,7 @@ get_foi_central_estimates <- function(seromodel_object,
#' @export
extract_seromodel_summary <- function(seromodel_object,
serodata) {
serodata <- validate_prepared_serodata(serodata)
#------- Loo estimates
# The argument parameter_name refers to the name given to the Log-likelihood
# in the stan models. See loo::extract_log_lik() documentation for further
Expand All @@ -463,7 +542,7 @@ extract_seromodel_summary <- function(seromodel_object,
} else {
lll <- c(-1e10, 0)
}
#-------

model_summary <- data.frame(
foi_model = seromodel_object@model_name,
dataset = unique(serodata$survey),
Expand Down Expand Up @@ -513,7 +592,7 @@ extract_seromodel_summary <- function(seromodel_object,
#' @examples
#' data(chagas2012)
#' serodata <- prepare_serodata(chagas2012)
#' seromodel_object <- run_seromodel(
#' seromodel_object <- fit_seromodel(
#' serodata = serodata,
#' foi_model = "constant"
#' )
Expand All @@ -532,8 +611,6 @@ get_prev_expanded <- function(foi,
bin_data <- FALSE
}


dim_foi <- dim(foi)[2]
foi_expanded <- foi

ly <- NCOL(foi_expanded)
Expand Down
Loading

0 comments on commit 13aaf57

Please sign in to comment.