-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added validation function for serodata
and fit_seromodel
. Fixes #148
#154
Changes from all commits
2323e1a
f10c0ba
ef879d7
b7752df
a066e39
b8bbe17
2c42e9e
1d9612b
2a5820a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,115 @@ | ||
stop_if_missing <- function(serodata, must_have_cols) { | ||
if ( | ||
!all( | ||
must_have_cols | ||
%in% colnames(serodata) | ||
) | ||
) { | ||
missing <- must_have_cols[which(!(must_have_cols %in% colnames(serodata)))] | ||
stop( | ||
"The following mandatory columns in `serodata` are missing.\n", | ||
toString(missing) | ||
) | ||
} | ||
} | ||
|
||
stop_if_wrong_type <- function(serodata, col_types) { | ||
error_messages <- list() | ||
for (col in names(col_types)) { | ||
# valid_col_types <- ifelse(is.list(col_types[[col]]), | ||
# col_types[[col]], as.list(col_types[[col]]) | ||
# ) | ||
valid_col_types <- as.list(col_types[[col]]) | ||
|
||
# Only validates column type if the column exists in the dataframe | ||
if (col %in% colnames(serodata) && | ||
!any(vapply(valid_col_types, function(type) { | ||
do.call(sprintf("is.%s", type), list(serodata[[col]])) | ||
}, logical(1)))) { | ||
error_messages <- append( | ||
error_messages, | ||
sprintf( | ||
"`%s` must be of any of these types: `%s`", | ||
col, toString(col_types[[col]]) | ||
) | ||
) | ||
} | ||
} | ||
if (length(error_messages) > 0) { | ||
stop( | ||
"The following columns in `serodata` have wrong types:\n", | ||
toString(error_messages) | ||
) | ||
} | ||
} | ||
|
||
warn_missing <- function(serodata, optional_cols) { | ||
if ( | ||
!all( | ||
optional_cols | ||
%in% colnames(serodata) | ||
) | ||
) { | ||
missing <- optional_cols[which(!(optional_cols %in% colnames(serodata)))] | ||
warning( | ||
"The following optional columns in `serodata` are missing.", | ||
"Consider including them to get more information from this analysis:\n", | ||
toString(missing) | ||
) | ||
for (col in missing) { | ||
serodata[[col]] <- "None" # TODO Shouln't we use `NA` instead? | ||
} | ||
} | ||
} | ||
|
||
|
||
validate_serodata <- function(serodata) { | ||
col_types <- list( | ||
survey = c("character", "factor"), | ||
total = "numeric", | ||
counts = "numeric", | ||
tsur = "numeric", | ||
age_min = "numeric", | ||
age_max = "numeric" | ||
) | ||
|
||
stop_if_missing(serodata, | ||
must_have_cols = names(col_types) | ||
) | ||
|
||
stop_if_wrong_type(serodata, col_types) | ||
|
||
optional_col_types <- list( | ||
country = c("character", "factor"), | ||
test = c("character", "factor"), | ||
antibody = c("character", "factor") | ||
) | ||
|
||
warn_missing(serodata, | ||
optional_cols = names(optional_col_types) | ||
) | ||
|
||
# If any optional column is present, validates that is has the correct type | ||
stop_if_wrong_type(serodata, optional_col_types) | ||
} | ||
|
||
validate_prepared_serodata <- function(serodata) { | ||
col_types <- list( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
For the time being we can also add |
||
total = "numeric", | ||
counts = "numeric", | ||
tsur = "numeric", | ||
age_mean_f = "numeric", | ||
birth_year = "numeric", | ||
prev_obs = "numeric", | ||
prev_obs_lower = "numeric", | ||
prev_obs_upper = "numeric" | ||
) | ||
validate_serodata(serodata) | ||
stop_if_missing(serodata, must_have_cols = names(col_types)) | ||
|
||
stop_if_wrong_type(serodata, col_types) | ||
} | ||
|
||
#' Function that runs the specified stan model for the Force-of-Infection and | ||
#' estimates the seroprevalence based on the result of the fit | ||
#' | ||
|
@@ -15,22 +127,21 @@ | |
#' data(chagas2012) | ||
#' serodata <- prepare_serodata(chagas2012) | ||
#' run_seromodel( | ||
#' serodata, | ||
#' foi_model = "constant" | ||
#' serodata, | ||
#' foi_model = "constant" | ||
#' ) | ||
#' @export | ||
run_seromodel <- function( | ||
serodata, | ||
foi_model = c("constant", "tv_normal_log", "tv_normal"), | ||
iter = 1000, | ||
thin = 2, | ||
adapt_delta = 0.90, | ||
max_treedepth = 10, | ||
chains = 4, | ||
seed = "12345", | ||
print_summary = TRUE, | ||
... | ||
) { | ||
serodata, | ||
foi_model = c("constant", "tv_normal_log", "tv_normal"), | ||
iter = 1000, | ||
thin = 2, | ||
adapt_delta = 0.90, | ||
max_treedepth = 10, | ||
chains = 4, | ||
seed = 12345, | ||
print_summary = TRUE, | ||
...) { | ||
foi_model <- match.arg(foi_model) | ||
survey <- unique(serodata$survey) | ||
if (length(survey) > 1) { | ||
|
@@ -118,19 +229,28 @@ run_seromodel <- function( | |
#' | ||
#' @export | ||
fit_seromodel <- function( | ||
serodata, | ||
foi_model = c("constant", "tv_normal_log", "tv_normal"), | ||
iter = 1000, | ||
thin = 2, | ||
adapt_delta = 0.90, | ||
max_treedepth = 10, | ||
chains = 4, | ||
seed = "12345", | ||
... | ||
) { | ||
serodata, | ||
foi_model = c("constant", "tv_normal_log", "tv_normal"), | ||
iter = 1000, | ||
thin = 2, | ||
adapt_delta = 0.90, | ||
max_treedepth = 10, | ||
chains = 4, | ||
seed = 12345, | ||
...) { | ||
# TODO Add a warning because there are exceptions where a minimal amount of | ||
# iterations is needed | ||
foi_model <- match.arg(foi_model) | ||
# Validate arguments | ||
validate_prepared_serodata(serodata) | ||
stopifnot( | ||
"foi_model must be either `constant`, `tv_normal_log`, or `tv_normal`" = foi_model %in% c("constant", "tv_normal_log", "tv_normal"), | ||
"iter must be numeric" = is.numeric(iter), | ||
"thin must be numeric" = is.numeric(thin), | ||
"adapt_delta must be numeric" = is.numeric(adapt_delta), | ||
"max_treedepth must be numeric" = is.numeric(max_treedepth), | ||
"chains must be numeric" = is.numeric(chains), | ||
"seed must be numeric" = is.numeric(seed) | ||
) | ||
model <- stanmodels[[foi_model]] | ||
cohort_ages <- get_cohort_ages(serodata = serodata) | ||
exposure_matrix <- get_exposure_matrix(serodata) | ||
|
@@ -415,7 +535,7 @@ get_prev_expanded <- function(foi, | |
exposure_expanded[apply( | ||
lower.tri(exposure_expanded, diag = TRUE), | ||
1, rev | ||
)] <- 1 | ||
)] <- 1 | ||
|
||
prev_pn <- t(1 - exp(-exposure_expanded %*% t(foi_expanded))) | ||
|
||
|
@@ -424,16 +544,17 @@ get_prev_expanded <- function(foi, | |
prev_pn, | ||
2, | ||
function(x) { | ||
quantile(x, | ||
c( | ||
0.5, | ||
predicted_prev_lower_quantile, | ||
predicted_prev_upper_quantile | ||
) | ||
) | ||
} | ||
) | ||
quantile( | ||
x, | ||
c( | ||
0.5, | ||
predicted_prev_lower_quantile, | ||
predicted_prev_upper_quantile | ||
) | ||
) | ||
} | ||
) | ||
) | ||
colnames(predicted_prev) <- c( | ||
"predicted_prev", | ||
"predicted_prev_lower", | ||
|
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add both
age_min
andage_max
to this list