Skip to content

Commit

Permalink
added type checking functions for serodata before and after calling…
Browse files Browse the repository at this point in the history
… `prepare_serodata`
  • Loading branch information
jpavlich committed Feb 3, 2024
1 parent 1b90cf6 commit 4977060
Showing 1 changed file with 74 additions and 19 deletions.
93 changes: 74 additions & 19 deletions R/modelling.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# TODO Complete @param documentation

stop_if_missing <- function(serodata, must_have_cols) {
if (
!all(
Expand All @@ -15,6 +13,36 @@ stop_if_missing <- function(serodata, must_have_cols) {
}
}

stop_if_wrong_type <- function(serodata, col_types) {
error_messages <- list()
for (col in names(col_types)) {
# valid_col_types <- ifelse(is.list(col_types[[col]]),
# col_types[[col]], as.list(col_types[[col]])
# )
valid_col_types <- as.list(col_types[[col]])

# Only validates column type if the column exists in the dataframe
if (col %in% colnames(serodata) &&
!any(vapply(valid_col_types, function(type) {
do.call(sprintf("is.%s", type), list(serodata[[col]]))
}, logical(1)))) {
error_messages <- append(
error_messages,
sprintf(
"`%s` must be of any of these types: `%s`",
col, toString(col_types[[col]])
)
)
}
}
if (length(error_messages) > 0) {
stop(
"The following columns in `serodata` have wrong types: ",
toString(error_messages)
)
}
}

warn_missing <- function(serodata, optional_cols) {
if (
!all(
Expand All @@ -34,14 +62,34 @@ warn_missing <- function(serodata, optional_cols) {
}
}


validate_serodata <- function(serodata) {
col_types <- list(
survey = c("character", "factor"),
total = "numeric",
counts = "numeric",
tsur = "numeric"
)

stop_if_missing(serodata,
must_have_cols = c("survey", "total", "counts", "tsur")
must_have_cols = names(col_types)
)

stop_if_wrong_type(serodata, col_types)

optional_col_types <- list(
country = c("character", "factor"),
test = c("character", "factor"),
antibody = c("character", "factor")
)

warn_missing(serodata,
optional_cols = c("country", "test", "antibody")
optional_cols = names(optional_col_types)
)

# If any optional column is present, validates that is has the correct type
stop_if_wrong_type(serodata, optional_col_types)

# Check that the serodata has the necessary columns to fully
# identify the age groups
stopifnot(
Expand All @@ -55,9 +103,17 @@ validate_serodata <- function(serodata) {
}

validate_prepared_serodata <- function(serodata) {
stop_if_missing(serodata,
must_have_cols = c("total", "counts", "tsur", "age_mean_f", "birth_year")
col_types <- list(
total = "numeric",
counts = "numeric",
tsur = "numeric",
age_mean_f = "numeric",
birth_year = "numeric"
)

stop_if_missing(serodata, must_have_cols = names(col_types))

stop_if_wrong_type(serodata, col_types)
}

#' Function that runs the specified stan model for the Force-of-Infection and
Expand Down Expand Up @@ -217,11 +273,9 @@ fit_seromodel <- function(serodata,
# TODO Add a warning because there are exceptions where a minimal amount of
# iterations is needed
# Validate arguments
# validate_serodata(serodata)
validate_prepared_serodata(serodata)
stopifnot(
"foi_model must be either `constant`, `tv_normal_log`, or `tv_normal`"
= foi_model %in% c("constant", "tv_normal_log", "tv_normal"),
"foi_model must be either `constant`, `tv_normal_log`, or `tv_normal`" = foi_model %in% c("constant", "tv_normal_log", "tv_normal"),
"n_iters must be numeric" = is.numeric(n_iters),
"n_thin must be numeric" = is.numeric(n_thin),
"delta must be numeric" = is.numeric(delta),
Expand Down Expand Up @@ -513,7 +567,7 @@ get_prev_expanded <- function(foi,
exposure_expanded[apply(
lower.tri(exposure_expanded, diag = TRUE),
1, rev
)] <- 1
)] <- 1

prev_pn <- t(1 - exp(-exposure_expanded %*% t(foi_expanded)))

Expand All @@ -522,16 +576,17 @@ get_prev_expanded <- function(foi,
prev_pn,
2,
function(x) {
quantile(x,
c(
0.5,
predicted_prev_lower_quantile,
predicted_prev_upper_quantile
)
)
}
)
quantile(
x,
c(
0.5,
predicted_prev_lower_quantile,
predicted_prev_upper_quantile
)
)
}
)
)
colnames(predicted_prev) <- c(
"predicted_prev",
"predicted_prev_lower",
Expand Down

0 comments on commit 4977060

Please sign in to comment.