Skip to content

Commit

Permalink
feat: set default behavior for sampling initialization
Browse files Browse the repository at this point in the history
- Set different init functions for log and no-log models
  • Loading branch information
ntorresd committed Aug 15, 2024
1 parent 9d64eae commit a569288
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
41 changes: 41 additions & 0 deletions R/fit_seromodel.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,38 @@ add_age_group_to_serosurvey <- function(serosurvey) {
return(serosurvey)
}

#' Sets initialization function for sampling
#'
#' @inheritParams fit_seromodel
#' @export
set_foi_init <- function(
foi_init,
is_log_foi,
foi_index
) {

# Set default behavior for initialization
if (is.null(foi_init)) {
config_file <- system.file("extdata", "config.yml", package = "serofoi")
init_default <- config::get(file = config_file, "priors")$defaults$init

if (is_log_foi) {
foi_init <- function() {
list(log_foi_vector = rep(log(init_default), max(foi_index)))
}
} else {
foi_init <- function() {
list(foi_vector = rep(init_default, max(foi_index)))
}
}
}

checkmate::assert_class(foi_init, "function")
checkmate::assert_double(unlist(foi_init()[[1]]), len = max(foi_index))

return(foi_init)
}

#' Runs specified stan model for the force-of-infection
#'
#' @param serosurvey
Expand Down Expand Up @@ -60,6 +92,7 @@ fit_seromodel <- function(
foi_prior = sf_normal(),
foi_sigma_rw = sf_none(),
foi_index = NULL,
foi_init = NULL,
is_seroreversion = FALSE,
seroreversion_prior = sf_uniform(),
...
Expand All @@ -79,6 +112,12 @@ fit_seromodel <- function(
seroreversion_prior = seroreversion_prior
)

foi_init <- set_foi_init(
foi_init = foi_init,
is_log_foi = is_log_foi,
foi_index = stan_data$foi_index
)

# Assigning right name to the model based on user specifications
model_name <- model_type
if (is_log_foi) {
Expand All @@ -91,9 +130,11 @@ fit_seromodel <- function(

# Compile or load Stan model
model <- stanmodels[[model_name]]

seromodel <- rstan::sampling(
model,
data = stan_data,
init = foi_init,
...
)
seromodel@model_name <- model_name
Expand Down
2 changes: 2 additions & 0 deletions inst/extdata/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@ default:
# cauchy
location: 0
scale: 1
# init
init: 0.1

0 comments on commit a569288

Please sign in to comment.