Skip to content

Commit

Permalink
Merge dcc35a6 into d002fb4
Browse files Browse the repository at this point in the history
  • Loading branch information
ntorresd authored May 10, 2024
2 parents d002fb4 + dcc35a6 commit 45b08ee
Show file tree
Hide file tree
Showing 29 changed files with 585 additions and 415 deletions.
49 changes: 49 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,52 @@
# serofoi development version

## Documentation

* Datasets `simdata_*` were removed from the package and replaced by corresponding code to simulate data in vignettes (see [#184](https://github.com/epiverse-trace/serofoi/pull/184)).

## Breaking changes

* Update R-hat convergence threshold to $\hat{R} < 1.01$ ([Vehtari, Aki, et al. 2021](https://projecteuclid.org/journals/bayesian-analysis/volume-16/issue-2/Rank-Normalization-Folding-and-Localization--An-Improved-R%cb%86-for/10.1214/20-BA1221.full))

* Add `av_normal` model without seroreversion.

* Allow specifying the parameters of the uniform prior $\sim U(a, b)$ of the `constant` model.

* Change initial prior parameters input specification in `fit_seromodel`. Now they are specified by means of the
parameter `foi_parameters` as follows:

```
# constant model
foi_model <- "constant"
foi_parameters <- list(
foi_a = 0.01,
foi_b = 0.1
)
# normal models
foi_model <- "tv_normal" # "tv_normal_log" or "av_normal"
foi_parameters <- list(
foi_location = 0.1,
foi_scale = 0.05
)
# running the model
seromodel <- fit_seromodel(
serodata = serodata,
foi_model = foi_model,
foi_parameters = foi_parameters
)
```

Note that the meaning of the parameters may vary depending on the model.


## Minor changes

* Add input validation for `plot_rhats`.

* The x-axis label in `plot_foi` and `plot_rhats` is `"age"` or `"year"` depending on the model type.

# serofoi 0.1.0

## New features
Expand Down
17 changes: 14 additions & 3 deletions R/model_comparison.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,24 @@
#' @export
get_table_rhats <- function(seromodel_object,
cohort_ages) {
checkmate::assert_class(seromodel_object, "stanfit")

rhats <- bayesplot::rhat(seromodel_object, "foi")

if (any(is.nan(rhats))) {
rhats[which(is.nan(rhats))] <- 0
warn_msg <- paste0(
length(which(is.nan(rhats))),
" rhat values are `nan`, ",
"indicating the model may not have run correctly for those times.\n",
"Setting those rhat values to `NA`."
)
warning(warn_msg)
rhats[which(is.nan(rhats))] <- NA
}
model_rhats <- data.frame(year = cohort_ages$birth_year, rhat = rhats)
model_rhats$rhat[model_rhats$rhat == 0] <- NA
model_rhats <- data.frame(
year = cohort_ages$birth_year,
rhat = rhats
)

return(model_rhats)
}
166 changes: 126 additions & 40 deletions R/modelling.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ validate_prepared_serodata <- function(serodata) {
#' \dontrun{
#' data(chagas2012)
#' serodata <- prepare_serodata(chagas2012)
#' run_seromodel(
#' fit_seromodel(
#' serodata,
#' foi_model = "constant"
#' )
Expand All @@ -149,10 +149,9 @@ validate_prepared_serodata <- function(serodata) {
run_seromodel <- function(
serodata,
foi_model = c("constant", "tv_normal_log", "tv_normal"),
foi_location = 0,
foi_scale = 1,
chunk_size = 1,
foi_parameters = NULL,
chunks = NULL,
chunk_size = 1,
iter = 1000,
adapt_delta = 0.90,
max_treedepth = 10,
Expand All @@ -168,10 +167,9 @@ run_seromodel <- function(
seromodel_object <- fit_seromodel(
serodata = serodata,
foi_model = foi_model,
foi_location = foi_location,
foi_scale = foi_scale,
chunk_size = chunk_size,
foi_parameters = foi_parameters,
chunks = chunks,
chunk_size = chunk_size,
iter = iter,
adapt_delta = adapt_delta,
max_treedepth = max_treedepth,
Expand Down Expand Up @@ -220,10 +218,13 @@ run_seromodel <- function(
#' \item{`"tv_normal"`}{Runs a normal model}
#' \item{`"tv_normal_log"`}{Runs a normal logarithmic model}
#' }
#' @param foi_location Location parameter of the force-of-infection distribution
#' of the selected model. Depending on `foi_model`, the meaning may vary.
#' @param foi_scale Scale parameter of the force-of-infection distribution
#' of the selected model. Depending on `foi_model`, the meaning may vary.
#' @param foi_parameters List specifying the initial prior parameters of the
#' model `foi_model` to be specified as (e.g.):
#' \describe{
#' \item{`"constant"`}{`list(foi_a = 0, foi_b = 2)`}
#' \item{`"tv_normal"`}{`list(foi_location = 0, foi_scale = 1)`}
#' \item{`"tv_normal_log"`}{`list(foi_location = -6, foi_scale = 4)`}
#' }
#' @param chunks Numeric list specifying the chunk structure of the time
#' interval from the birth year of the oldest age cohort
#' `min(serodata$age_mean_f)` to the time when the serosurvey was conducted
Expand Down Expand Up @@ -261,9 +262,8 @@ run_seromodel <- function(
#' @export
fit_seromodel <- function(
serodata,
foi_model = c("constant", "tv_normal_log", "tv_normal"),
foi_location = 0,
foi_scale = 1,
foi_model = c("constant", "tv_normal_log", "tv_normal", "av_normal"),
foi_parameters = NULL,
chunks = NULL,
chunk_size = 1,
iter = 1000,
Expand All @@ -272,16 +272,48 @@ fit_seromodel <- function(
seed = 12345,
...) {
serodata <- validate_prepared_serodata(serodata)
err_msg <- paste0(
"foi_model must be either ",
"constant, ",
"tv_normal, tv_normal_log, or ",
"av_normal"
)
stopifnot(
"foi_model must be either `constant`, `tv_normal_log`, or `tv_normal`" =
foi_model %in% c("constant", "tv_normal_log", "tv_normal"),
err_msg =
foi_model %in% c(
"constant",
"tv_normal", "tv_normal_log",
"av_normal"
),
"iter must be numeric" = is.numeric(iter),
"seed must be numeric" = is.numeric(seed)
)

# Set default foi parameters
if (is.null(foi_parameters)) {
if (foi_model == "constant") {
foi_parameters <- list(
foi_a = 0,
foi_b = 2
)
} else if (foi_model %in% c("tv_normal", "av_normal")) {
foi_parameters <- list(
foi_location = 0,
foi_scale = 1
)
} else if (foi_model == "tv_normal_log") {
foi_parameters <- list(
foi_location = -6,
foi_scale = 4
)
}
}

# Load Stan model
model <- stanmodels[[foi_model]]
exposure_matrix <- get_exposure_matrix(serodata)
n_obs <- nrow(serodata)

# Set default chunks structure
if (is.null(chunks)) {
chunks <- get_chunk_structure(
serodata = serodata,
Expand All @@ -294,17 +326,48 @@ fit_seromodel <- function(
length(chunks) == max(serodata$age_mean_f)
)

# Build Stan data
stan_data <- list(
n_obs = n_obs,
n_obs = nrow(serodata),
n_pos = serodata$counts,
n_total = serodata$total,
age_max = max(serodata$age_mean_f),
observation_exposure_matrix = exposure_matrix,
chunks = chunks,
foi_location = foi_location,
foi_scale = foi_scale
chunks = chunks
)

if (foi_model %in% c("constant", "tv_normal", "tv_normal_log")) {
exposure_matrix <- get_exposure_matrix(serodata)
stan_data <- append(
stan_data,
list(
observation_exposure_matrix = exposure_matrix
)
)
} else if (foi_model == "av_normal") {
stan_data <- append(
stan_data,
list(ages = serodata$age_mean_f)
)
}

if (foi_model == "constant") {
stan_data <- append(
stan_data,
list(
foi_a = foi_parameters$foi_a,
foi_b = foi_parameters$foi_b
)
)
} else {
stan_data <- append(
stan_data,
list(
foi_location = foi_parameters$foi_location,
foi_scale = foi_parameters$foi_scale
)
)
}

if (foi_model == "tv_normal_log") {
f_init <- function() {
list(log_fois = rep(-3, max(chunks)))
Expand Down Expand Up @@ -447,6 +510,10 @@ get_exposure_matrix <- function(serodata) {
#' model by means of [run_seromodel].
#' @param cohort_ages A data frame containing the age of each cohort
#' corresponding to each birth year.
#' @param lower_quantile Lower quantile used to compute the credible interval of
#' the fitted force-of-infection.
#' @param upper_quantile Upper quantile used to compute the credible interval of
#' the fitted force-of-infection.
#' @return `foi_central_estimates`. Central estimates for the fitted force-of-infection (FoI)
#' @examples
#' data(chagas2012)
Expand All @@ -461,27 +528,37 @@ get_exposure_matrix <- function(serodata) {
#' cohort_ages = cohort_ages
#' )
#' @export
get_foi_central_estimates <- function(seromodel_object,
cohort_ages) {
if (seromodel_object@model_name == "tv_normal_log") {
lower_quantile <- 0.1
upper_quantile <- 0.9
medianv_quantile <- 0.5
} else {
lower_quantile <- 0.05
upper_quantile <- 0.95
medianv_quantile <- 0.5
}
get_foi_central_estimates <- function(
seromodel_object,
cohort_ages,
lower_quantile = 0.05,
upper_quantile = 0.95
) {
# extracts force-of-infection from stan fit
foi <- rstan::extract(seromodel_object, "foi", inc_warmup = FALSE)[[1]]

# defines time scale depending on the type of the model
if(
seromodel_object@model_name %in%
c("constant", "tv_normal", "tv_normal_log")
) {
foi_central_estimates <- data.frame(
year = cohort_ages$birth_year
)
} else if (seromodel_object@model_name == "av_normal") {
foi_central_estimates <- data.frame(
age = rev(cohort_ages$age)
)
}

# generates central estimations
foi_central_estimates <- data.frame(
year = cohort_ages$birth_year,
lower = apply(foi, 2, quantile, lower_quantile),
upper = apply(foi, 2, quantile, upper_quantile),
medianv = apply(foi, 2, quantile, medianv_quantile)
)
foi_central_estimates <- foi_central_estimates %>%
mutate(
lower = apply(foi, 2, quantile, lower_quantile),
upper = apply(foi, 2, quantile, upper_quantile),
medianv = apply(foi, 2, quantile, 0.5)
)

return(foi_central_estimates)
}

Expand Down Expand Up @@ -559,8 +636,17 @@ extract_seromodel_summary <- function(seromodel_object,
seromodel_object = seromodel_object,
cohort_ages = cohort_ages
)
if (!any(rhats$rhat > 1.1)) {

if (all(rhats$rhat <= 1.01)) {
model_summary$converged <- "Yes"
} else {
model_summary$converged <- "No"
warn_msg <- paste0(
length(which(rhats$rhat > 1.01)),
" rhat values are above 1.01. ",
"Running the chains for more iterations is recommended."
)
warning(warn_msg)
}

return(model_summary)
Expand Down
18 changes: 0 additions & 18 deletions R/simdata_constant.R

This file was deleted.

18 changes: 0 additions & 18 deletions R/simdata_large_epi.R

This file was deleted.

18 changes: 0 additions & 18 deletions R/simdata_sw_dec.R

This file was deleted.

Loading

0 comments on commit 45b08ee

Please sign in to comment.