Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CmdStanMCMC/MLE/VB from read_cmdstan_csv ouput #412

Merged
merged 19 commits into from
Dec 20, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: cmdstanr
Title: R Interface to 'CmdStan'
Version: 0.3.0
Version: 0.3.0.9000
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bumping to dev version

Date: 2020-12-17
Authors@R:
c(person(given = "Jonah", family = "Gabry", role = c("aut", "cre"),
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Generated by roxygen2: do not edit by hand

export(as_cmdstan_mcmc)
export(as_cmdstan_mle)
export(as_cmdstan_vb)
export(check_cmdstan_toolchain)
export(cmdstan_default_install_path)
export(cmdstan_default_path)
Expand Down
9 changes: 9 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
# cmdstanr 0.3.0.9000

### Bug fixes

### New features

* New functions `as_cmdstan_mcmc()`, `as_cmdstan_mle()`, `as_cmdstan_vb()` that
create CmdStanMCMC/MLE/VB objects directly from CmdStan CSV files. (#412)

# cmdstanr 0.3.0

### Bug fixes
Expand Down
78 changes: 36 additions & 42 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ CmdStanFit <- R6::R6Class(
self$runset$num_procs()
},
print = function(variables = NULL, ..., digits = 2, max_rows = 10) {
if (!length(self$output_files(include_failed = FALSE))) {
stop("Fitting failed. Unable to print.", call. = FALSE)
}
# filter variables before passing to summary to avoid computing anything
# that won't be printed because of max_rows
all_variables <- self$metadata()$model_params
Expand Down Expand Up @@ -330,7 +327,7 @@ CmdStanFit$set("public", name = "lp", value = lp)
#'
summary <- function(variables = NULL, ...) {
draws <- self$draws(variables)
if (self$runset$method() == "sample") {
if (self$metadata()$method == "sample") {
rok-cesnovar marked this conversation as resolved.
Show resolved Hide resolved
summary <- posterior::summarise_draws(draws, ...)
} else {
if (!length(list(...))) {
Expand All @@ -345,7 +342,7 @@ summary <- function(variables = NULL, ...) {
summary <- posterior::summarise_draws(draws, ...)
}
}
if (self$runset$method() == "optimize") {
if (self$metadata()$method == "optimize") {
summary <- summary[, c("variable", "mean")]
colnames(summary) <- c("variable", "estimate")
}
Expand Down Expand Up @@ -599,10 +596,10 @@ CmdStanFit$set("public", name = "output", value = output)
#' }
#'
metadata <- function() {
if (!length(self$output_files(include_failed = FALSE))) {
stop("Fitting failed. Unable to retrieve the metadata.", call. = FALSE)
}
if (is.null(private$metadata_)) {
if (!length(self$output_files(include_failed = FALSE))) {
stop("Fitting failed. Unable to retrieve the metadata.", call. = FALSE)
}
rok-cesnovar marked this conversation as resolved.
Show resolved Hide resolved
private$read_csv_()
}
private$metadata_
Expand Down Expand Up @@ -705,17 +702,17 @@ CmdStanMCMC <- R6::R6Class(
} else {
if (self$runset$args$validate_csv) {
fixed_param <- runset$args$method_args$fixed_param
data_csv <- read_cmdstan_csv(
csv_contents <- read_cmdstan_csv(
rok-cesnovar marked this conversation as resolved.
Show resolved Hide resolved
self$output_files(),
variables = "",
sampler_diagnostics =
if (!fixed_param) c("treedepth__", "divergent__") else ""
)
if (!fixed_param) {
check_divergences(data_csv)
check_sampler_transitions_treedepth(data_csv)
check_divergences(csv_contents)
check_sampler_transitions_treedepth(csv_contents)
}
private$metadata_ <- data_csv$metadata
private$metadata_ <- csv_contents$metadata
}
}
},
Expand All @@ -730,9 +727,6 @@ CmdStanMCMC <- R6::R6Class(

# override the CmdStanFit draws method
draws = function(variables = NULL, inc_warmup = FALSE) {
if (!length(self$output_files(include_failed = FALSE))) {
rok-cesnovar marked this conversation as resolved.
Show resolved Hide resolved
stop("No chains finished successfully. Unable to retrieve the draws.", call. = FALSE)
}
if (inc_warmup && !private$metadata_$save_warmup) {
stop("Warmup draws were requested from a fit object without them! ",
"Please rerun the model with save_warmup = TRUE.", call. = FALSE)
Expand Down Expand Up @@ -773,45 +767,45 @@ CmdStanMCMC <- R6::R6Class(
if (!length(self$output_files(include_failed = FALSE))) {
stop("No chains finished successfully. Unable to retrieve the draws.", call. = FALSE)
}
data_csv <- read_cmdstan_csv(
csv_contents <- read_cmdstan_csv(
rok-cesnovar marked this conversation as resolved.
Show resolved Hide resolved
files = self$output_files(include_failed = FALSE),
variables = variables,
sampler_diagnostics = sampler_diagnostics
)
private$inv_metric_ <- data_csv$inv_metric
private$metadata_ <- data_csv$metadata
private$inv_metric_ <- csv_contents$inv_metric
private$metadata_ <- csv_contents$metadata

if (!is.null(data_csv$post_warmup_draws)) {
missing_variables <- !(posterior::variables(data_csv$post_warmup_draws) %in% posterior::variables(private$draws_))
if (!is.null(csv_contents$post_warmup_draws)) {
missing_variables <- !(posterior::variables(csv_contents$post_warmup_draws) %in% posterior::variables(private$draws_))
private$draws_ <- posterior::bind_draws(
private$draws_,
data_csv$post_warmup_draws[,,missing_variables],
csv_contents$post_warmup_draws[,,missing_variables],
along="variable"
)
}
if (!is.null(data_csv$post_warmup_sampler_diagnostics)) {
missing_variables <- !(posterior::variables(data_csv$post_warmup_sampler_diagnostics) %in% posterior::variables(private$sampler_diagnostics_))
if (!is.null(csv_contents$post_warmup_sampler_diagnostics)) {
missing_variables <- !(posterior::variables(csv_contents$post_warmup_sampler_diagnostics) %in% posterior::variables(private$sampler_diagnostics_))
private$sampler_diagnostics_ <- posterior::bind_draws(
private$sampler_diagnostics_,
data_csv$post_warmup_sampler_diagnostics[,,missing_variables],
csv_contents$post_warmup_sampler_diagnostics[,,missing_variables],
along="variable"
)
}
if (!is.null(data_csv$metadata$save_warmup)
&& data_csv$metadata$save_warmup) {
if (!is.null(data_csv$warmup_draws)) {
missing_variables <- !(posterior::variables(data_csv$warmup_draws) %in% posterior::variables(private$warmup_draws_))
if (!is.null(csv_contents$metadata$save_warmup)
&& csv_contents$metadata$save_warmup) {
if (!is.null(csv_contents$warmup_draws)) {
missing_variables <- !(posterior::variables(csv_contents$warmup_draws) %in% posterior::variables(private$warmup_draws_))
private$warmup_draws_ <- posterior::bind_draws(
private$warmup_draws_,
data_csv$warmup_draws[,,missing_variables],
csv_contents$warmup_draws[,,missing_variables],
along="variable"
)
}
if (!is.null(data_csv$warmup_sampler_diagnostics)) {
missing_variables <- !(posterior::variables(data_csv$warmup_sampler_diagnostics) %in% posterior::variables(private$warmup_sampler_diagnostics_))
if (!is.null(csv_contents$warmup_sampler_diagnostics)) {
missing_variables <- !(posterior::variables(csv_contents$warmup_sampler_diagnostics) %in% posterior::variables(private$warmup_sampler_diagnostics_))
private$warmup_sampler_diagnostics_ <- posterior::bind_draws(
private$warmup_sampler_diagnostics_,
data_csv$warmup_sampler_diagnostics[,,missing_variables],
csv_contents$warmup_sampler_diagnostics[,,missing_variables],
along="variable"
)
}
Expand Down Expand Up @@ -1052,9 +1046,9 @@ CmdStanMLE <- R6::R6Class(
if (!length(self$output_files(include_failed = FALSE))) {
stop("Optimization failed. Unable to retrieve the draws.", call. = FALSE)
}
optim_output <- read_cmdstan_csv(self$output_files())
private$draws_ <- optim_output$point_estimates
private$metadata_ <- optim_output$metadata
csv_contents <- read_cmdstan_csv(self$output_files())
private$draws_ <- csv_contents$point_estimates
private$metadata_ <- csv_contents$metadata
invisible(self)
}
)
Expand Down Expand Up @@ -1156,9 +1150,9 @@ CmdStanVB <- R6::R6Class(
if (!length(self$output_files(include_failed = FALSE))) {
stop("Variational inference failed. Unable to retrieve the draws.", call. = FALSE)
}
vb_output <- read_cmdstan_csv(self$output_files())
private$draws_ <- vb_output$draws
private$metadata_ <- vb_output$metadata
csv_contents <- read_cmdstan_csv(self$output_files())
private$draws_ <- csv_contents$draws
private$metadata_ <- csv_contents$metadata
invisible(self)
}
)
Expand Down Expand Up @@ -1270,17 +1264,17 @@ CmdStanGQ <- R6::R6Class(
if (!length(self$output_files(include_failed = FALSE))) {
stop("Generating quantities for all MCMC chains failed. Unable to retrieve the generated quantities.", call. = FALSE)
}
data_csv <- read_cmdstan_csv(
csv_contents <- read_cmdstan_csv(
files = self$output_files(include_failed = FALSE),
variables = variables,
sampler_diagnostics = ""
)
private$metadata_ <- data_csv$metadata
if (!is.null(data_csv$generated_quantities)) {
private$metadata_ <- csv_contents$metadata
if (!is.null(csv_contents$generated_quantities)) {
private$draws_ <-
posterior::bind_draws(
private$draws_,
data_csv$generated_quantities,
csv_contents$generated_quantities,
along="variable"
)
}
Expand Down
Loading