Skip to content

Commit

Permalink
Merge pull request #419 from stan-dev/faster_CSV_read
Browse files Browse the repository at this point in the history
Faster CSV read with multiple chains
  • Loading branch information
jgabry authored Jan 4, 2021
2 parents 3a652c9 + ef3d394 commit a27d933
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 93 deletions.
75 changes: 33 additions & 42 deletions R/csv.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ read_cmdstan_csv <- function(files,
sampler_diagnostics = NULL) {
checkmate::assert_file_exists(files, access = "r", extension = "csv")
metadata <- NULL
warmup_draws <- NULL
warmup_sampler_diagnostics_draws <- NULL
post_warmup_draws <- NULL
post_warmup_sampler_diagnostics_draws <- NULL
generated_quantities <- NULL
warmup_draws <- list()
post_warmup_draws <- list()
warmup_sampler_diagnostics_draws <- list()
post_warmup_sampler_diagnostics_draws <- list()
generated_quantities <- list()
variational_draws <- NULL
point_estimates <- NULL
inv_metric <- list()
Expand Down Expand Up @@ -241,49 +241,25 @@ read_cmdstan_csv <- function(files,
if (metadata$method == "sample") {
if (metadata$save_warmup == 1) {
if (length(variables) > 0) {
warmup_draws <- posterior::bind_draws(
warmup_draws,
posterior::as_draws_array(draws[1:num_warmup_draws, variables, drop = FALSE]),
along="chain"
)
warmup_draws[[length(warmup_draws) + 1]] <- draws[1:num_warmup_draws, variables, drop = FALSE]
if (num_post_warmup_draws > 0) {
post_warmup_draws <- posterior::bind_draws(
post_warmup_draws,
posterior::as_draws_array(draws[(num_warmup_draws+1):all_draws, variables, drop = FALSE]),
along="chain"
)
post_warmup_draws[[length(post_warmup_draws) + 1]] <- draws[(num_warmup_draws+1):all_draws, variables, drop = FALSE]
}
}
if (length(sampler_diagnostics) > 0) {
warmup_sampler_diagnostics_draws <- posterior::bind_draws(
warmup_sampler_diagnostics_draws,
posterior::as_draws_array(draws[1:num_warmup_draws, sampler_diagnostics, drop = FALSE]),
along="chain"
)
warmup_sampler_diagnostics_draws[[length(warmup_sampler_diagnostics_draws) + 1]] <- draws[1:num_warmup_draws, sampler_diagnostics, drop = FALSE]
if (num_post_warmup_draws > 0) {
post_warmup_sampler_diagnostics_draws <- posterior::bind_draws(
post_warmup_sampler_diagnostics_draws,
posterior::as_draws_array(draws[(num_warmup_draws+1):all_draws, sampler_diagnostics, drop = FALSE]),
along="chain"
)
post_warmup_sampler_diagnostics_draws[[length(post_warmup_sampler_diagnostics_draws) + 1]] <- draws[(num_warmup_draws+1):all_draws, sampler_diagnostics, drop = FALSE]
}
}
} else {
warmup_draws <- NULL
warmup_sampler_diagnostics_draws <- NULL
if (length(variables) > 0) {
post_warmup_draws <- posterior::bind_draws(
post_warmup_draws,
posterior::as_draws_array(draws[, variables, drop = FALSE]),
along="chain"
)
post_warmup_draws[[length(post_warmup_draws) + 1]] <- draws[, variables, drop = FALSE]
}
if (length(sampler_diagnostics) > 0 && all(metadata$algorithm != "fixed_param")) {
post_warmup_sampler_diagnostics_draws <- posterior::bind_draws(
post_warmup_sampler_diagnostics_draws,
posterior::as_draws_array(draws[, sampler_diagnostics, drop = FALSE]),
along="chain"
)
post_warmup_sampler_diagnostics_draws[[length(post_warmup_sampler_diagnostics_draws) + 1]] <- draws[, sampler_diagnostics, drop = FALSE]
}
}
} else if (metadata$method == "variational") {
Expand All @@ -300,9 +276,7 @@ read_cmdstan_csv <- function(files,
} else if (metadata$method == "optimize") {
point_estimates <- posterior::as_draws_matrix(draws[1,, drop=FALSE])[, variables]
} else if (metadata$method == "generate_quantities") {
generated_quantities <- posterior::bind_draws(generated_quantities,
posterior::as_draws_array(draws),
along="chain")
generated_quantities[[length(generated_quantities) + 1]] <- draws
}
}
}
Expand All @@ -313,7 +287,6 @@ read_cmdstan_csv <- function(files,
}

metadata$inv_metric <- NULL
metadata$lines_to_skip <- NULL
metadata$model_params <- repair_variable_names(metadata$model_params)
repaired_variables <- repair_variable_names(variables)
if (metadata$method == "variational") {
Expand All @@ -330,12 +303,16 @@ read_cmdstan_csv <- function(files,
metadata$stan_variables <- names(model_param_dims)

if (metadata$method == "sample") {
warmup_draws <- bind_list_of_draws_array(warmup_draws)
if (!is.null(warmup_draws)) {
posterior::variables(warmup_draws) <- repaired_variables
}
post_warmup_draws <- bind_list_of_draws_array(post_warmup_draws)
if (!is.null(post_warmup_draws)) {
posterior::variables(post_warmup_draws) <- repaired_variables
}
warmup_sampler_diagnostics_draws <- bind_list_of_draws_array(warmup_sampler_diagnostics_draws)
post_warmup_sampler_diagnostics_draws <- bind_list_of_draws_array(post_warmup_sampler_diagnostics_draws)
list(
metadata = metadata,
time = list(total = NA_integer_, chains = time),
Expand Down Expand Up @@ -363,6 +340,7 @@ read_cmdstan_csv <- function(files,
point_estimates = point_estimates
)
} else if (metadata$method == "generate_quantities") {
generated_quantities <- bind_list_of_draws_array(generated_quantities)
if (!is.null(generated_quantities)) {
posterior::variables(generated_quantities) <- repaired_variables
}
Expand Down Expand Up @@ -422,8 +400,8 @@ CmdStanMCMC_CSV <- R6::R6Class(
public = list(
initialize = function(csv_contents, files, check_diagnostics = TRUE) {
if (check_diagnostics) {
check_divergences(csv_contents)
check_sampler_transitions_treedepth(csv_contents)
check_divergences(csv_contents$post_warmup_sampler_diagnostics)
check_sampler_transitions_treedepth(csv_contents$post_warmup_sampler_diagnostics, csv_contents$metadata)
}
private$output_files_ <- files
private$metadata_ <- csv_contents$metadata
Expand Down Expand Up @@ -708,7 +686,20 @@ check_csv_metadata_matches <- function(a, b) {
list(not_matching = not_matching)
}


# Combine a list of per-file draws into a single draws_array.
#
# `draws` is a list whose elements are each converted with
# posterior::as_draws_array(); `along` is forwarded to posterior::bind_draws()
# when there is more than one element to combine.
# Returns NULL when `draws` is NULL or empty, the converted single element
# when there is exactly one, and the bound result otherwise.
bind_list_of_draws_array <- function(draws, along = "chain") {
  n_pieces <- length(draws)  # length(NULL) is 0, so NULL is covered here too

  # Nothing was collected: signal "no draws" with NULL.
  if (n_pieces == 0) {
    return(NULL)
  }

  # A single element needs no binding, only conversion.
  if (n_pieces == 1) {
    return(posterior::as_draws_array(draws[[1]]))
  }

  # Convert every element, then bind them all in one call instead of
  # binding pairwise; `along` is passed as a named argument via the list.
  bind_args <- lapply(draws, posterior::as_draws_array)
  bind_args[["along"]] <- along
  do.call(posterior::bind_draws, bind_args)
}

# convert names like beta.1.1 to beta[1,1]
repair_variable_names <- function(names) {
Expand Down
77 changes: 44 additions & 33 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -708,17 +708,12 @@ CmdStanMCMC <- R6::R6Class(
} else {
if (self$runset$args$validate_csv) {
fixed_param <- runset$args$method_args$fixed_param
csv_contents <- read_cmdstan_csv(
self$output_files(),
variables = "",
sampler_diagnostics =
if (!fixed_param) c("treedepth__", "divergent__") else ""
)
private$read_csv_(variables = "",
sampler_diagnostics = if (!fixed_param) c("treedepth__", "divergent__") else "")
if (!fixed_param) {
check_divergences(csv_contents)
check_sampler_transitions_treedepth(csv_contents)
check_divergences(private$sampler_diagnostics_)
check_sampler_transitions_treedepth(private$sampler_diagnostics_, private$metadata_)
}
private$metadata_ <- csv_contents$metadata
}
}
},
Expand Down Expand Up @@ -782,38 +777,54 @@ CmdStanMCMC <- R6::R6Class(
private$metadata_ <- csv_contents$metadata

if (!is.null(csv_contents$post_warmup_draws)) {
missing_variables <- !(posterior::variables(csv_contents$post_warmup_draws) %in% posterior::variables(private$draws_))
private$draws_ <- posterior::bind_draws(
private$draws_,
csv_contents$post_warmup_draws[,,missing_variables],
along="variable"
)
if (is.null(private$draws_)) {
private$draws_ <- csv_contents$post_warmup_draws
} else {
missing_variables <- !(posterior::variables(csv_contents$post_warmup_draws) %in% posterior::variables(private$draws_))
private$draws_ <- posterior::bind_draws(
private$draws_,
csv_contents$post_warmup_draws[,,missing_variables],
along="variable"
)
}
}
if (!is.null(csv_contents$post_warmup_sampler_diagnostics)) {
missing_variables <- !(posterior::variables(csv_contents$post_warmup_sampler_diagnostics) %in% posterior::variables(private$sampler_diagnostics_))
private$sampler_diagnostics_ <- posterior::bind_draws(
private$sampler_diagnostics_,
csv_contents$post_warmup_sampler_diagnostics[,,missing_variables],
along="variable"
)
if (is.null(private$sampler_diagnostics_)) {
private$sampler_diagnostics_ <- csv_contents$post_warmup_sampler_diagnostics
} else {
missing_variables <- !(posterior::variables(csv_contents$post_warmup_sampler_diagnostics) %in% posterior::variables(private$sampler_diagnostics_))
private$sampler_diagnostics_ <- posterior::bind_draws(
private$sampler_diagnostics_,
csv_contents$post_warmup_sampler_diagnostics[,,missing_variables],
along="variable"
)
}
}
if (!is.null(csv_contents$metadata$save_warmup)
&& csv_contents$metadata$save_warmup) {
if (!is.null(csv_contents$warmup_draws)) {
missing_variables <- !(posterior::variables(csv_contents$warmup_draws) %in% posterior::variables(private$warmup_draws_))
private$warmup_draws_ <- posterior::bind_draws(
private$warmup_draws_,
csv_contents$warmup_draws[,,missing_variables],
along="variable"
)
if (is.null(private$warmup_draws_)) {
private$warmup_draws_ <- csv_contents$warmup_draws
} else {
missing_variables <- !(posterior::variables(csv_contents$warmup_draws) %in% posterior::variables(private$warmup_draws_))
private$warmup_draws_ <- posterior::bind_draws(
private$warmup_draws_,
csv_contents$warmup_draws[,,missing_variables],
along="variable"
)
}
}
if (!is.null(csv_contents$warmup_sampler_diagnostics)) {
missing_variables <- !(posterior::variables(csv_contents$warmup_sampler_diagnostics) %in% posterior::variables(private$warmup_sampler_diagnostics_))
private$warmup_sampler_diagnostics_ <- posterior::bind_draws(
private$warmup_sampler_diagnostics_,
csv_contents$warmup_sampler_diagnostics[,,missing_variables],
along="variable"
)
if (is.null(private$warmup_sampler_diagnostics_)) {
private$warmup_sampler_diagnostics_ <- csv_contents$warmup_sampler_diagnostics
} else {
missing_variables <- !(posterior::variables(csv_contents$warmup_sampler_diagnostics) %in% posterior::variables(private$warmup_sampler_diagnostics_))
private$warmup_sampler_diagnostics_ <- posterior::bind_draws(
private$warmup_sampler_diagnostics_,
csv_contents$warmup_sampler_diagnostics[,,missing_variables],
along="variable"
)
}
}
}
invisible(self)
Expand Down
19 changes: 9 additions & 10 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,9 @@ set_num_threads <- function(num_threads) {
call. = FALSE)
}

check_divergences <- function(csv_contents) {
if (!is.null(csv_contents$post_warmup_sampler_diagnostics)) {
divergences <- posterior::extract_variable_matrix(csv_contents$post_warmup_sampler_diagnostics, "divergent__")
check_divergences <- function(post_warmup_sampler_diagnostics) {
if (!is.null(post_warmup_sampler_diagnostics)) {
divergences <- posterior::extract_variable_matrix(post_warmup_sampler_diagnostics, "divergent__")
num_of_draws <- length(divergences)
num_of_divergences <- sum(divergences)
if (!is.na(num_of_divergences) && num_of_divergences > 0) {
Expand All @@ -274,17 +274,16 @@ check_divergences <- function(csv_contents) {
}
}

check_sampler_transitions_treedepth <- function(csv_contents) {
if (!is.null(csv_contents$post_warmup_sampler_diagnostics)) {
treedepth <- posterior::extract_variable_matrix(csv_contents$post_warmup_sampler_diagnostics, "treedepth__")
check_sampler_transitions_treedepth <- function(post_warmup_sampler_diagnostics, metadata) {
if (!is.null(post_warmup_sampler_diagnostics)) {
treedepth <- posterior::extract_variable_matrix(post_warmup_sampler_diagnostics, "treedepth__")
num_of_draws <- length(treedepth)
max_treedepth <- csv_contents$metadata$max_treedepth
max_treedepth_hit <- sum(treedepth >= max_treedepth)
max_treedepth_hit <- sum(treedepth >= metadata$max_treedepth)
if (!is.na(max_treedepth_hit) && max_treedepth_hit > 0) {
percentage_max_treedepth <- (max_treedepth_hit)/num_of_draws*100
message(max_treedepth_hit, " of ", num_of_draws, " (", (format(round(percentage_max_treedepth, 0), nsmall = 1)), "%)",
" transitions hit the maximum treedepth limit of ", max_treedepth,
" or 2^", max_treedepth, "-1 leapfrog steps.\n",
" transitions hit the maximum treedepth limit of ", metadata$max_treedepth,
" or 2^", metadata$max_treedepth, "-1 leapfrog steps.\n",
"Trajectories that are prematurely terminated due to this limit will result in slow exploration.\n",
"Increasing the max_treedepth limit can avoid this at the expense of more computation.\n",
"If increasing max_treedepth does not remove warnings, try to reparameterize the model.\n")
Expand Down
24 changes: 24 additions & 0 deletions tests/testthat/test-fit-mcmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,29 @@ test_that("draws() stops for unkown variables", {
)
})

# Requests variables one at a time (lp__, then alpha, then beta), each with
# inc_warmup = TRUE, and checks that every call returns a double-typed
# draws_array containing exactly the requested variables.
test_that("draws() works when gradually adding variables", {
skip_on_cran()
# Fit with CSV validation on and warmup draws saved, so the inc_warmup = TRUE
# requests below have warmup draws to include.
fit <- testing_fit("logistic", method = "sample", refresh = 0,
validate_csv = TRUE, save_warmup = TRUE)

# First request: only lp__, plus the sampler diagnostics.
draws_lp__ <- fit$draws(variables = c("lp__"), inc_warmup = TRUE)
sampler_diagnostics <- fit$sampler_diagnostics(inc_warmup = TRUE)
expect_type(draws_lp__, "double")
expect_s3_class(draws_lp__, "draws_array")
expect_equal(posterior::variables(draws_lp__), c("lp__"))
expect_type(sampler_diagnostics, "double")
expect_s3_class(sampler_diagnostics, "draws_array")
# NOTE(review): treedepth__ and divergent__ are expected first, presumably
# because validate_csv = TRUE reads those two diagnostics before the rest
# are added -- confirm against CmdStanMCMC initialization.
expect_equal(posterior::variables(sampler_diagnostics), c(c("treedepth__", "divergent__", "accept_stat__", "stepsize__", "n_leapfrog__", "energy__")))
# Second request: a different variable on the same fit object.
draws_alpha <- fit$draws(variables = c("alpha"), inc_warmup = TRUE)
expect_type(draws_alpha, "double")
expect_s3_class(draws_alpha, "draws_array")
expect_equal(posterior::variables(draws_alpha), c("alpha"))
# Third request: "beta" is expected to come back as its three elements.
draws_beta <- fit$draws(variables = c("beta"), inc_warmup = TRUE)
expect_type(draws_beta, "double")
expect_s3_class(draws_beta, "draws_array")
expect_equal(posterior::variables(draws_beta), c("beta[1]", "beta[2]", "beta[3]"))
})

test_that("draws() method returns draws_array (reading csv works)", {
skip_on_cran()
draws <- fit_mcmc$draws()
Expand Down Expand Up @@ -273,3 +296,4 @@ test_that("loo errors if it can't find log lik variables", {
fixed = TRUE
)
})

30 changes: 22 additions & 8 deletions tests/testthat/test-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,18 @@ test_that("check_divergences() works", {
csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv"))
csv_output <- read_cmdstan_csv(csv_files)
output <- "14 of 100 \\(14.0%\\) transitions ended with a divergence."
expect_message(check_divergences(csv_output), output)
expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), output)

csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv"),
test_path("resources", "csv", "model1-2-no-warmup.csv"))
csv_output <- read_cmdstan_csv(csv_files)
output <- "28 of 200 \\(14.0%\\) transitions ended with a divergence."
expect_message(check_divergences(csv_output), output)
expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), output)

csv_files <- c(test_path("resources", "csv", "model1-2-warmup.csv"))
csv_output <- read_cmdstan_csv(csv_files)
output <- "1 of 100 \\(1.0%\\) transitions ended with a divergence."
expect_message(check_divergences(csv_output), output)
expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), output)


fit_wramup_no_samples <- testing_fit("logistic", method = "sample",
Expand All @@ -32,27 +32,41 @@ test_that("check_divergences() works", {
save_warmup = TRUE,
validate_csv = FALSE)
csv_output <- read_cmdstan_csv(fit_wramup_no_samples$output_files())
expect_message(check_divergences(csv_output), regexp = NA)
expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), regexp = NA)
})

test_that("check_sampler_transitions_treedepth() works", {
skip_on_cran()
csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv"))
csv_output <- read_cmdstan_csv(csv_files)
output <- "16 of 100 \\(16.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps."
expect_message(check_sampler_transitions_treedepth(csv_output), output)
expect_message(
check_sampler_transitions_treedepth(
csv_output$post_warmup_sampler_diagnostics,
csv_output$metadata),
output
)

csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv"),
test_path("resources", "csv", "model1-2-no-warmup.csv"))
csv_output <- read_cmdstan_csv(csv_files)
output <- "32 of 200 \\(16.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps."
expect_message(check_sampler_transitions_treedepth(csv_output), output)
expect_message(
check_sampler_transitions_treedepth(
csv_output$post_warmup_sampler_diagnostics,
csv_output$metadata),
output
)

csv_files <- c(test_path("resources", "csv", "model1-2-warmup.csv"))
csv_output <- read_cmdstan_csv(csv_files)
output <- "1 of 100 \\(1.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps."
expect_message(check_sampler_transitions_treedepth(csv_output), output)

expect_message(
check_sampler_transitions_treedepth(
csv_output$post_warmup_sampler_diagnostics,
csv_output$metadata),
output
)
})

test_that("cmdstan_summary works if bin/stansummary deleted file", {
Expand Down

0 comments on commit a27d933

Please sign in to comment.