Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added validation function for serodata and fit_seromodel. Fixes #148 #154

Merged
merged 9 commits into from
Feb 13, 2024
189 changes: 155 additions & 34 deletions R/modelling.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,115 @@
stop_if_missing <- function(serodata, must_have_cols) {
if (
!all(
must_have_cols
%in% colnames(serodata)
)
) {
missing <- must_have_cols[which(!(must_have_cols %in% colnames(serodata)))]
stop(
"The following mandatory columns in `serodata` are missing.\n",
toString(missing)
)
}
}

stop_if_wrong_type <- function(serodata, col_types) {
error_messages <- list()
for (col in names(col_types)) {
# valid_col_types <- ifelse(is.list(col_types[[col]]),
# col_types[[col]], as.list(col_types[[col]])
# )
valid_col_types <- as.list(col_types[[col]])

# Only validates column type if the column exists in the dataframe
if (col %in% colnames(serodata) &&
!any(vapply(valid_col_types, function(type) {
do.call(sprintf("is.%s", type), list(serodata[[col]]))
}, logical(1)))) {
error_messages <- append(
error_messages,
sprintf(
"`%s` must be of any of these types: `%s`",
col, toString(col_types[[col]])
)
)
}
}
if (length(error_messages) > 0) {
stop(
"The following columns in `serodata` have wrong types:\n",
toString(error_messages)
)
}
}

warn_missing <- function(serodata, optional_cols) {
if (
!all(
optional_cols
%in% colnames(serodata)
)
) {
missing <- optional_cols[which(!(optional_cols %in% colnames(serodata)))]
warning(
"The following optional columns in `serodata` are missing.",
"Consider including them to get more information from this analysis:\n",
toString(missing)
)
for (col in missing) {
serodata[[col]] <- "None" # TODO Shouln't we use `NA` instead?
}
}
}


validate_serodata <- function(serodata) {
col_types <- list(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add both age_min and age_max to this list

survey = c("character", "factor"),
total = "numeric",
counts = "numeric",
tsur = "numeric",
age_min = "numeric",
age_max = "numeric"
)

stop_if_missing(serodata,
must_have_cols = names(col_types)
)

stop_if_wrong_type(serodata, col_types)

optional_col_types <- list(
country = c("character", "factor"),
test = c("character", "factor"),
antibody = c("character", "factor")
)

warn_missing(serodata,
optional_cols = names(optional_col_types)
)

# If any optional column is present, validates that is has the correct type
stop_if_wrong_type(serodata, optional_col_types)
}

validate_prepared_serodata <- function(serodata) {
col_types <- list(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

total, counts and tsur are already validated by validate_serodata(serodata) in line 113. Here we should just make sure that both age_mean_f and birth_year had been added to the data.

For the time being we can also add prev_obs, prev_obs_lower and prev_obs_upper (which should be numeric) for consistency with the current version of prepare_serodata. Although they're not needed for modelling, they're currently used for plotting purposes, so to simplify data validation for those functions I think it's worth adding them here. In the future we may refactor prepare_serodata for it just to prepare the data for modelling and compute the prevalence with its binomial confidence interval internally in the plotting functions (if we decide to keep the visualization module in the package), but we can decide this later.

total = "numeric",
counts = "numeric",
tsur = "numeric",
age_mean_f = "numeric",
birth_year = "numeric",
prev_obs = "numeric",
prev_obs_lower = "numeric",
prev_obs_upper = "numeric"
)
validate_serodata(serodata)
stop_if_missing(serodata, must_have_cols = names(col_types))

stop_if_wrong_type(serodata, col_types)
}

#' Function that runs the specified stan model for the Force-of-Infection and
#' estimates the seroprevalence based on the result of the fit
#'
Expand All @@ -15,22 +127,21 @@
#' data(chagas2012)
#' serodata <- prepare_serodata(chagas2012)
#' run_seromodel(
#' serodata,
#' foi_model = "constant"
#' serodata,
#' foi_model = "constant"
#' )
#' @export
run_seromodel <- function(
serodata,
foi_model = c("constant", "tv_normal_log", "tv_normal"),
iter = 1000,
thin = 2,
adapt_delta = 0.90,
max_treedepth = 10,
chains = 4,
seed = "12345",
print_summary = TRUE,
...
) {
serodata,
foi_model = c("constant", "tv_normal_log", "tv_normal"),
iter = 1000,
thin = 2,
adapt_delta = 0.90,
max_treedepth = 10,
chains = 4,
seed = 12345,
print_summary = TRUE,
...) {
foi_model <- match.arg(foi_model)
survey <- unique(serodata$survey)
if (length(survey) > 1) {
Expand Down Expand Up @@ -118,19 +229,28 @@ run_seromodel <- function(
#'
#' @export
fit_seromodel <- function(
serodata,
foi_model = c("constant", "tv_normal_log", "tv_normal"),
iter = 1000,
thin = 2,
adapt_delta = 0.90,
max_treedepth = 10,
chains = 4,
seed = "12345",
...
) {
serodata,
foi_model = c("constant", "tv_normal_log", "tv_normal"),
iter = 1000,
thin = 2,
adapt_delta = 0.90,
max_treedepth = 10,
chains = 4,
seed = 12345,
...) {
# TODO Add a warning because there are exceptions where a minimal amount of
# iterations is needed
foi_model <- match.arg(foi_model)
# Validate arguments
validate_prepared_serodata(serodata)
stopifnot(
"foi_model must be either `constant`, `tv_normal_log`, or `tv_normal`" = foi_model %in% c("constant", "tv_normal_log", "tv_normal"),
"iter must be numeric" = is.numeric(iter),
"thin must be numeric" = is.numeric(thin),
"adapt_delta must be numeric" = is.numeric(adapt_delta),
"max_treedepth must be numeric" = is.numeric(max_treedepth),
"chains must be numeric" = is.numeric(chains),
"seed must be numeric" = is.numeric(seed)
)
model <- stanmodels[[foi_model]]
cohort_ages <- get_cohort_ages(serodata = serodata)
exposure_matrix <- get_exposure_matrix(serodata)
Expand Down Expand Up @@ -415,7 +535,7 @@ get_prev_expanded <- function(foi,
exposure_expanded[apply(
lower.tri(exposure_expanded, diag = TRUE),
1, rev
)] <- 1
)] <- 1

prev_pn <- t(1 - exp(-exposure_expanded %*% t(foi_expanded)))

Expand All @@ -424,16 +544,17 @@ get_prev_expanded <- function(foi,
prev_pn,
2,
function(x) {
quantile(x,
c(
0.5,
predicted_prev_lower_quantile,
predicted_prev_upper_quantile
)
)
}
)
quantile(
x,
c(
0.5,
predicted_prev_lower_quantile,
predicted_prev_upper_quantile
)
)
}
)
)
colnames(predicted_prev) <- c(
"predicted_prev",
"predicted_prev_lower",
Expand Down
58 changes: 6 additions & 52 deletions R/seroprevalence_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,60 +45,9 @@
prepare_serodata <- function(serodata = serodata,
alpha = 0.05) {
checkmate::assert_numeric(alpha, lower = 0, upper = 1)
# Check that serodata has the right columns
cols_check <- c("survey", "total", "counts", "tsur")
if (
!all(
cols_check
%in% colnames(serodata)
)
) {
stop(
"serodata must contain the right columns. ",
sprintf(
"Column(s) (%s) are missing.", toString(
cols_check[which(!(cols_check %in% colnames(serodata)))]
)
)
)
}

# Check that the serodata has the necessary columns to fully
# identify the age groups
stopifnot(
"serodata must contain both 'age_min' and 'age_max',
or 'age_mean_f' to fully identify the age groups" =
all(c(
"age_min", "age_max"
) %in% colnames(serodata)) |
"age_mean_f" %in% colnames(serodata)
)

if (!any(colnames(serodata) == "country")) {
warning(
"Column 'country' is missing. ",
"Consider adding it as additional information."
)
serodata$country <- "None"
}
validate_serodata(serodata)


if (!any(colnames(serodata) == "test")) {
warning(
"Column 'test' is missing. ",
"Consider adding it as additional information."
)
serodata$test <- "None"
}

if (!any(colnames(serodata) == "antibody")) {
warning(
"Column 'antibody' is missing. ",
"Consider adding it as additional information."
)
serodata$antibody <- "None"
}

if (!any(colnames(serodata) == "age_mean_f")) {
serodata <- serodata %>%
dplyr::mutate(
Expand All @@ -113,6 +62,7 @@ prepare_serodata <- function(serodata = serodata,
birth_year = .data$tsur - .data$age_mean_f
)
}

serodata <- serodata %>%
cbind(
Hmisc::binconf(
Expand Down Expand Up @@ -314,8 +264,12 @@ generate_sim_data <- function(sim_data,
sample_size_by_age,
seed = seed
)

# TODO Improve simulation of age_min and age_max
sim_data <- sim_data %>%
mutate(
age_min = .data$age_mean_f,
age_max = .data$age_mean_f,
counts = sim_n_seropositive$n_seropositive,
total = sample_size_by_age,
survey = survey_label
Expand Down
2 changes: 1 addition & 1 deletion man/fit_seromodel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/run_seromodel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion tests/testthat/test_sim_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ test_that("simulated data", {
expect_equal(sim_data$prev_obs, prev_exact, tolerance = TRUE)

#----- Test function group_sim_data
sim_data <- sim_data %>% mutate(age_min = age_mean_f, age_max = age_mean_f)
sim_data_grouped <- group_sim_data(sim_data = sim_data)
expect_s3_class(sim_data_grouped, "data.frame")
expect_s3_class(sim_data_grouped$age_group, "factor")
Expand Down
Loading