Skip to content

Commit

Permalink
Issue 735: Fix truncation slicing when t < truncation (#736)
Browse files Browse the repository at this point in the history
* Fix truncation slicing when t < truncation

* add news item

* add unit tests for truncate vs relying on integration testing

* add documentation for observation models

* Apply suggestions from code review

Co-authored-by: Sebastian Funk <sebastian.funk@lshtm.ac.uk>

* hack around truncate overloading

* use the correct stan functions

* rename truncate -> truncate_obs

* Update NEWS.md

Co-authored-by: James Azam <james.azam@lshtm.ac.uk>

---------

Co-authored-by: Sebastian Funk <sebastian.funk@lshtm.ac.uk>
Co-authored-by: James Azam <james.azam@lshtm.ac.uk>
  • Loading branch information
3 people committed Aug 9, 2024
1 parent 03fc897 commit 22e5b22
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 20 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
- a bug was fixed that caused delay option functions to report an error if only the tolerance was specified. By @sbfnk in #716 and reviewed by @jamesmbaazam.
- a bug was fixed where `forecast_secondary()` did not work with fixed delays. By @sbfnk in #717 and reviewed by @seabbs.
- a bug was fixed that caused delay option functions to report an error if only the tolerance was specified. By @sbfnk.
- a bug was fixed that led to the truncation PMF being shortened from the wrong side when the truncation PMF was longer than the supplied data. By @seabbs in #736 and reviewed by @sbfnk and @jamesmbaazam.

## Documentation

Expand Down
2 changes: 1 addition & 1 deletion inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ transformed parameters {
);
}
profile("truncate") {
obs_reports = truncate(reports[1:ot], trunc_rev_cmf, 0);
obs_reports = truncate_obs(reports[1:ot], trunc_rev_cmf, 0);
}
} else {
obs_reports = reports[1:ot];
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/estimate_secondary.stan
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ transformed parameters {
delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist,
0, 1, 1
);
secondary = truncate(secondary, trunc_rev_cmf, 0);
secondary = truncate_obs(secondary, trunc_rev_cmf, 0);
}
}

Expand Down
6 changes: 3 additions & 3 deletions inst/stan/estimate_truncation.stan
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ transformed parameters{
vector[t] last_obs;
// reconstruct latest data without truncation

last_obs = truncate(to_vector(obs[, obs_sets]), trunc_rev_cmf, 1);
last_obs = truncate_obs(to_vector(obs[, obs_sets]), trunc_rev_cmf, 1);
// apply truncation to latest dataset to map back to previous data sets and
// add noise term
for (i in 1:(obs_sets - 1)) {
trunc_obs[1:(end_t[i] - start_t[i] + 1), i] =
truncate(last_obs[start_t[i]:end_t[i]], trunc_rev_cmf, 0) + sigma;
truncate_obs(last_obs[start_t[i]:end_t[i]], trunc_rev_cmf, 0) + sigma;
}
}
}
Expand Down Expand Up @@ -80,7 +80,7 @@ generated quantities {
matrix[delay_type_max[trunc_id] + 1, obs_sets - 1] gen_obs;
// reconstruct all truncated datasets using posterior of the truncation distribution
for (i in 1:obs_sets) {
recon_obs[1:(end_t[i] - start_t[i] + 1), i] = truncate(
recon_obs[1:(end_t[i] - start_t[i] + 1), i] = truncate_obs(
to_vector(obs[start_t[i]:end_t[i], i]), trunc_rev_cmf, 1
);
}
Expand Down
110 changes: 97 additions & 13 deletions inst/stan/functions/observation_model.stan
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
// apply day of the week effect
/**
* Apply day of the week effect to reports
*
* This function applies a day of the week effect to a vector of reports.
*
* @param reports Vector of reports to be adjusted.
* @param day_of_week Array of integers representing the day of the week for each report.
* @param effect Vector of day of week effects.
*
* @return A vector of reports adjusted for day of the week effects.
*/
vector day_of_week_effect(vector reports, array[] int day_of_week, vector effect) {
int t = num_elements(reports);
int wl = num_elements(effect);
Expand All @@ -11,30 +21,65 @@ vector day_of_week_effect(vector reports, array[] int day_of_week, vector effect
}
return(scaled_reports);
}
// Scale observations by fraction reported and update log density of
// fraction reported

/**
 * Rescale reports by the fraction observed
 *
 * Multiplies every element of a vector of reports by a single scalar
 * fraction.
 *
 * @param reports Vector of reports to rescale.
 * @param frac_obs Real-valued fraction observed, applied to every element.
 *
 * @return A vector of the same length as `reports` holding the rescaled values.
 */
vector scale_obs(vector reports, real frac_obs) {
  // Scalar-by-vector product; no intermediate container is needed.
  return frac_obs * reports;
}
// Truncate observed data by some truncation distribution
vector truncate(vector reports, vector trunc_rev_cmf, int reconstruct) {

/**
* Truncate or reconstruct observed data using a truncation distribution
*
* Adjusts the trailing entries of a vector of reports element-wise by entries
* of the truncation reverse CMF: multiplying applies truncation, dividing
* undoes it. Entries of `reports` before index `first_t` are returned
* unchanged.
*
* @param reports Vector of reports to be adjusted.
* @param trunc_rev_cmf Vector representing the reverse cumulative mass function of the truncation distribution.
* @param reconstruct Integer flag: non-zero divides by the CMF (reconstruct), zero multiplies by it (truncate).
*
* @return A vector of the same length as `reports`, truncated or reconstructed over its trailing entries.
*/
vector truncate_obs(vector reports, vector trunc_rev_cmf, int reconstruct) {
int t = num_elements(reports);
int trunc_max = num_elements(trunc_rev_cmf);
vector[t] trunc_reports = reports;
// Work out how many trailing entries of reports overlap with the truncation
// CMF and the index at which that overlap starts in each vector.
// NOTE(review): this listing appears to interleave pre- and post-change lines
// of the diff (trunc_max and first_t are each declared twice, and each branch
// below has two assignments) - confirm which lines are live before editing.
int trunc_max = min(t, num_elements(trunc_rev_cmf));
int first_t = t - trunc_max + 1;
int joint_max = min(t, trunc_max);
int first_t = t - joint_max + 1;
int first_trunc = trunc_max - joint_max + 1;

// Apply the truncation CMF to the trailing entries of reports:
// divide to reconstruct, multiply to truncate. The first_trunc:trunc_max
// slice takes the trailing entries of the CMF, so a CMF longer than reports
// is shortened from its start rather than its end.
if (reconstruct) {
trunc_reports[first_t:t] ./= trunc_rev_cmf[1:trunc_max];
trunc_reports[first_t:t] ./= trunc_rev_cmf[first_trunc:trunc_max];
} else {
trunc_reports[first_t:t] .*= trunc_rev_cmf[1:trunc_max];
trunc_reports[first_t:t] .*= trunc_rev_cmf[first_trunc:trunc_max];
}
return(trunc_reports);
}
// Truncation distribution priors

/**
* Update log density for truncation distribution priors
*
* This function updates the log density for truncation distribution priors.
*
* @param truncation_mean Array of real values for truncation mean.
* @param truncation_sd Array of real values for truncation standard deviation.
* @param trunc_mean_mean Array of real values for mean of truncation mean prior.
* @param trunc_mean_sd Array of real values for standard deviation of truncation mean prior.
* @param trunc_sd_mean Array of real values for mean of truncation standard deviation prior.
* @param trunc_sd_sd Array of real values for standard deviation of truncation standard deviation prior.
*/
void truncation_lp(array[] real truncation_mean, array[] real truncation_sd,
array[] real trunc_mean_mean, array[] real trunc_mean_sd,
array[] real trunc_sd_mean, array[] real trunc_sd_sd) {
Expand All @@ -50,7 +95,22 @@ void truncation_lp(array[] real truncation_mean, array[] real truncation_sd,
}
}
}
// update log density for reported cases

/**
* Update log density for reported cases
*
* This function updates the log density for reported cases based on the specified model type.
*
* @param cases Array of integer observed cases.
* @param cases_time Array of integer time indices for observed cases.
* @param reports Vector of expected reports.
* @param rep_phi Array of real values for reporting overdispersion.
* @param phi_mean Real value for mean of reporting overdispersion prior.
* @param phi_sd Real value for standard deviation of reporting overdispersion prior.
* @param model_type Integer indicating the model type (0 for Poisson, >0 for Negative Binomial).
* @param weight Real value for weighting the log density contribution.
* @param accumulate Integer flag indicating whether to accumulate reports (1) or not (0).
*/
void report_lp(array[] int cases, array[] int cases_time, vector reports,
array[] real rep_phi, real phi_mean, real phi_sd,
int model_type, real weight, int accumulate) {
Expand Down Expand Up @@ -96,7 +156,20 @@ void report_lp(array[] int cases, array[] int cases_time, vector reports,
}
}
}
// update log likelihood (as above but not vectorised and returning log likelihood)

/**
* Calculate log likelihood for reported cases
*
* This function calculates the log likelihood for reported cases based on the specified model type.
*
* @param cases Array of integer observed cases.
* @param reports Vector of expected reports.
* @param rep_phi Array of real values for reporting overdispersion.
* @param model_type Integer indicating the model type (0 for Poisson, >0 for Negative Binomial).
* @param weight Real value for weighting the log likelihood contribution.
*
* @return A vector of log likelihoods for each time point.
*/
vector report_log_lik(array[] int cases, vector reports,
array[] real rep_phi, int model_type, real weight) {
int t = num_elements(reports);
Expand All @@ -115,7 +188,18 @@ vector report_log_lik(array[] int cases, vector reports,
}
return(log_lik);
}
// sample reported cases from the observation model

/**
* Generate random samples of reported cases
*
* This function generates random samples of reported cases based on the specified model type.
*
* @param reports Vector of expected reports.
* @param rep_phi Array of real values for reporting overdispersion.
* @param model_type Integer indicating the model type (0 for Poisson, >0 for Negative Binomial).
*
* @return An array of integer sampled reports.
*/
array[] int report_rng(vector reports, array[] real rep_phi, int model_type) {
int t = num_elements(reports);
array[t] int sampled_reports;
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/simulate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ generated quantities {
delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist,
0, 1, 1
);
reports[i] = to_row_vector(truncate(
reports[i] = to_row_vector(truncate_obs(
to_vector(reports[i]), trunc_rev_cmf, 0)
);
}
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/simulate_secondary.stan
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ generated quantities {
delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist,
0, 1, 1
);
secondary = truncate(
secondary = truncate_obs(
secondary, trunc_rev_cmf, 0
);
}
Expand Down
30 changes: 30 additions & 0 deletions tests/testthat/test-stan-truncate.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Skip every test in this file on CRAN and on Windows.
# NOTE(review): presumably because truncate_obs() must first be compiled from
# the package's Stan sources - confirm against the test setup.
skip_on_cran()
skip_on_os("windows")

test_that("truncate_obs() can perform truncation as expected", {
  # Four-entry CMF against five reports: the first report is left untouched
  # and the remaining four are scaled element-wise.
  cmf <- c(1, 0.8, 0.5, 0.2)
  obs <- c(10, 20, 30, 40, 50)
  truncated <- truncate_obs(obs, cmf, FALSE)
  expect_equal(truncated, c(obs[1], obs[-1] * cmf))
})

test_that("truncate_obs() can perform reconstruction as expected", {
  # Reconstruction divides the last four reports by the CMF and leaves the
  # first report as supplied.
  cmf <- c(1, 0.8, 0.5, 0.2)
  obs <- c(10, 20, 15, 8, 10)
  reconstructed <- truncate_obs(obs, cmf, TRUE)
  expect_equal(reconstructed, c(obs[1], obs[-1] / cmf))
})

test_that("truncate_obs() can handle longer trunc_rev_cmf than reports", {
  # Five-entry CMF against three reports: only the last three CMF entries
  # should be applied.
  cmf <- c(1, 0.8, 0.5, 0.2, 0.1)
  obs <- c(10, 20, 30)
  truncated <- truncate_obs(obs, cmf, FALSE)
  expect_equal(truncated, obs * cmf[-(1:2)])
})

test_that("truncate_obs() can handle reconstruction with longer trunc_rev_cmf than reports", {
  # Five-entry CMF against three reports: reconstruction divides by the last
  # three CMF entries only.
  cmf <- c(1, 0.8, 0.5, 0.2, 0.1)
  obs <- c(10, 16, 15)
  reconstructed <- truncate_obs(obs, cmf, TRUE)
  expect_equal(reconstructed, obs / cmf[-(1:2)])
})

0 comments on commit 22e5b22

Please sign in to comment.