Skip to content

Commit

Permalink
Merge pull request #746 from stan-dev/stdout-and-stderr
Browse files Browse the repository at this point in the history
Update handling of show_messages, add show_exceptions
  • Loading branch information
andrjohns authored Mar 24, 2023
2 parents 8bb211a + d24a359 commit a30d3e9
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 11 deletions.
8 changes: 6 additions & 2 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,7 @@ sample <- function(data = NULL,
window = NULL,
fixed_param = FALSE,
show_messages = TRUE,
show_exceptions = TRUE,
diagnostics = c("divergences", "treedepth", "ebfmi"),
# deprecated
cores = NULL,
Expand Down Expand Up @@ -1124,7 +1125,8 @@ sample <- function(data = NULL,
num_procs = checkmate::assert_integerish(chains, lower = 1, len = 1),
parallel_procs = checkmate::assert_integerish(parallel_chains, lower = 1, null.ok = TRUE),
threads_per_proc = assert_valid_threads(threads_per_chain, self$cpp_options(), multiple_chains = TRUE),
show_stderr_messages = show_messages
show_stderr_messages = show_exceptions,
show_stdout_messages = show_messages
)
model_variables <- NULL
if (is_variables_method_supported(self)) {
Expand Down Expand Up @@ -1261,6 +1263,7 @@ sample_mpi <- function(data = NULL,
fixed_param = FALSE,
sig_figs = NULL,
show_messages = TRUE,
show_exceptions = TRUE,
diagnostics = c("divergences", "treedepth", "ebfmi"),
# deprecated
validate_csv = TRUE) {
Expand All @@ -1283,7 +1286,8 @@ sample_mpi <- function(data = NULL,
procs <- CmdStanMCMCProcs$new(
num_procs = checkmate::assert_integerish(chains, lower = 1, len = 1),
parallel_procs = 1,
show_stderr_messages = show_messages
show_stderr_messages = show_exceptions,
show_stdout_messages = show_messages
)
model_variables <- NULL
if (is_variables_method_supported(self)) {
Expand Down
36 changes: 29 additions & 7 deletions R/run.R
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,13 @@ check_target_exe <- function(exe) {
}
}
if (is.null(procs$threads_per_proc())) {
cat(paste0(start_msg, "...\n\n"))
if (procs$show_stdout_messages()) {
cat(paste0(start_msg, "...\n\n"))
}
} else {
cat(paste0(start_msg, ", with ", procs$threads_per_proc(), " thread(s) per chain...\n\n"))
if (procs$show_stdout_messages()) {
cat(paste0(start_msg, ", with ", procs$threads_per_proc(), " thread(s) per chain...\n\n"))
}
Sys.setenv("STAN_NUM_THREADS" = as.integer(procs$threads_per_proc()))
# Windows environment variables have to be explicitly exported to WSL
if (os_is_wsl()) {
Expand Down Expand Up @@ -425,9 +429,13 @@ CmdStanRun$set("private", name = "run_sample_", value = .run_sample)
}
}
if (is.null(procs$threads_per_proc())) {
cat(paste0(start_msg, "...\n\n"))
if (procs$show_stdout_messages()) {
cat(paste0(start_msg, "...\n\n"))
}
} else {
cat(paste0(start_msg, ", with ", procs$threads_per_proc(), " thread(s) per chain...\n\n"))
if (procs$show_stdout_messages()) {
cat(paste0(start_msg, ", with ", procs$threads_per_proc(), " thread(s) per chain...\n\n"))
}
Sys.setenv("STAN_NUM_THREADS" = as.integer(procs$threads_per_proc()))
# Windows environment variables have to be explicitly exported to WSL
if (os_is_wsl()) {
Expand Down Expand Up @@ -612,6 +620,12 @@ CmdStanProcs <- R6::R6Class(
private$show_stdout_messages_ <- show_stdout_messages
invisible(self)
},
show_stdout_messages = function () {
private$show_stdout_messages_
},
show_stderr_messages = function () {
private$show_stderr_messages_
},
num_procs = function() {
private$num_procs_
},
Expand Down Expand Up @@ -927,7 +941,7 @@ CmdStanMCMCProcs <- R6::R6Class(
|| grepl("stancflags", line, fixed = TRUE)) {
ignore_line <- TRUE
}
if ((state > 1.5 && state < 5 && !ignore_line) || is_verbose_mode()) {
if ((state > 1.5 && state < 5 && !ignore_line && private$show_stdout_messages_) || is_verbose_mode()) {
if (state == 2) {
message("Chain ", id, " ", line)
} else {
Expand All @@ -939,7 +953,9 @@ CmdStanMCMCProcs <- R6::R6Class(
if (state == 1) {
state <- 2;
}
message("Chain ", id, " ", line)
if (private$show_stderr_messages_) {
message("Chain ", id, " ", line)
}
}
private$proc_state_[[id]] <- next_state
} else {
Expand All @@ -951,6 +967,9 @@ CmdStanMCMCProcs <- R6::R6Class(
invisible(self)
},
report_time = function(id = NULL) {
if (!private$show_stdout_messages_) {
return(invisible(NULL))
}
if (!is.null(id)) {
if (self$proc_state(id) == 7) {
warning("Chain ", id, " finished unexpectedly!\n", immediate. = TRUE, call. = FALSE)
Expand Down Expand Up @@ -1030,7 +1049,7 @@ CmdStanGQProcs <- R6::R6Class(
if (nzchar(line)) {
if (self$proc_state(id) == 1 && grepl("refresh = ", line, perl = TRUE)) {
self$set_proc_state(id, new_state = 1.5)
} else if (self$proc_state(id) >= 2) {
} else if (self$proc_state(id) >= 2 && private$show_stdout_messages_) {
cat("Chain", id, line, "\n")
}
} else {
Expand All @@ -1044,6 +1063,9 @@ CmdStanGQProcs <- R6::R6Class(
invisible(self)
},
report_time = function(id = NULL) {
if (!private$show_stdout_messages_) {
return(invisible(NULL))
}
if (!is.null(id)) {
if (self$proc_state(id) == 7) {
warning("Chain ", id, " finished unexpectedly!\n", immediate. = TRUE, call. = FALSE)
Expand Down
4 changes: 4 additions & 0 deletions man-roxygen/model-sample-args.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@
#' `fixed_param=TRUE` is mandatory. When `fixed_param=TRUE` the `chains` and
#' `parallel_chains` arguments will be set to `1`.
#' @param show_messages (logical) When `TRUE` (the default), prints all
#' output during the sampling process, such as iteration numbers and elapsed times.
#' If the output is silenced then the [`$output()`][fit-method-output] method of
#' the resulting fit object can be used to display the silenced messages.
#' @param show_exceptions (logical) When `TRUE` (the default), prints all
#' informational messages, for example rejection of the current proposal.
#' Disable if you wish to silence these messages, but this is not usually
#' recommended unless you are very confident that the model is correct up to
Expand Down
6 changes: 6 additions & 0 deletions man/model-method-sample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions man/model-method-sample_mpi.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 43 additions & 2 deletions tests/testthat/test-model-sample.R
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ test_that("sample() method runs when the stan file is removed", {
)
})

test_that("sample() prints informational messages depening on show_messages", {
test_that("sample() prints informational messages depening on show_exceptions", {
mod_info_msg <- testing_model("info_message")
expect_sample_output(
expect_message(
Expand All @@ -127,7 +127,7 @@ test_that("sample() prints informational messages depening on show_messages", {
)
)
expect_sample_output(
expect_message(mod_info_msg$sample(show_messages = FALSE), regexp = NA)
expect_message(mod_info_msg$sample(show_exceptions = FALSE), regexp = NA)
)
})

Expand Down Expand Up @@ -321,3 +321,44 @@ test_that("sig_figs warning if version less than 2.25", {
)
reset_cmdstan_version()
})

test_that("Errors are suppressed with show_exceptions", {
errmodcode <- "
data {
real y_mean;
}
transformed data {
vector[1] small;
small[2] = 1.0;
}
parameters {
real y;
}
model {
y ~ normal(y_mean, 1);
}
"
errmod <- cmdstan_model(write_stan_file(errmodcode), force_recompile = TRUE)

expect_message(
suppressWarnings(errmod$sample(data = list(y_mean = 1), chains = 1)),
"Chain 1 Exception: vector[uni] assign: accessing element out of range",
fixed = TRUE
)

expect_no_message(
suppressWarnings(errmod$sample(data = list(y_mean = 1), chains = 1, show_exceptions = FALSE))
)
})

test_that("All output can be suppressed by show_messages", {
stan_program <- testing_stan_file("bernoulli")
data_list <- testing_data("bernoulli")
mod <- cmdstan_model(stan_program, force_recompile = TRUE)
options("cmdstanr_verbose" = FALSE)
output <- capture.output(
fit <- mod$sample(data = data_list, show_messages = FALSE)
)

expect_length(output, 0)
})

0 comments on commit a30d3e9

Please sign in to comment.