Skip to content

Commit

Permalink
Merge pull request #519 from stan-dev/add_model_variables
Browse files Browse the repository at this point in the history
Add `$variables()`
  • Loading branch information
jgabry authored Aug 17, 2021
2 parents 0872512 + fa98252 commit 58f0980
Show file tree
Hide file tree
Showing 18 changed files with 577 additions and 203 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ to gradients computed via finite differences. (#485)
so that models do not get unnecessarily recompiled when calling the function
multiple times with the same code. (#495, @martinmodrak)

* New method `$variables()` for CmdStanModel objects that returns a list of
variables in the Stan model, their types and number of dimensions. Does
not require the model to be compiled. (#519)

* `write_stan_json()` now handles data of class `"table"`. Tables are converted
to vector, matrix, or array depending on the dimensions of the table. (#528)

Expand Down
2 changes: 1 addition & 1 deletion R/csv.R
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ read_cmdstan_csv <- function(files,
repaired_variables <- gsub("log_g__", "lp_approx__", repaired_variables)
}
model_param_dims <- variable_dims(metadata$model_params)
metadata$stan_variable_dims <- model_param_dims
metadata$stan_variable_sizes <- model_param_dims
metadata$stan_variables <- names(model_param_dims)

if (metadata$method == "sample") {
Expand Down
68 changes: 67 additions & 1 deletion R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ CmdStanModel <- R6::R6Class(
include_paths_ = NULL,
precompile_cpp_options_ = NULL,
precompile_stanc_options_ = NULL,
precompile_include_paths_ = NULL
precompile_include_paths_ = NULL,
variables_ = NULL
),
public = list(
initialize = function(stan_file, compile, ...) {
Expand Down Expand Up @@ -527,6 +528,48 @@ compile <- function(quiet = TRUE,
}
CmdStanModel$set("public", name = "compile", value = compile)

#' Input and output variables of a Stan program
#'
#' @name model-method-variables
#' @aliases variables
#' @family CmdStanModel methods
#'
#' @description The `$variables()` method of a [`CmdStanModel`] object returns
#' a list, each element representing a Stan model block: `data`, `parameters`,
#' `transformed_parameters` and `generated_quantities`.
#'
#' Each element contains a list of variables, with each variable represented
#' as a list with information on its scalar type (`real` or `int`) and
#' number of dimensions.
#'
#' `transformed data` is not included, as variables in that block are not
#' part of the model's input or output.
#'
#' @return The `$variables()` method returns a list with information on
#' input and output variables for each of the Stan model blocks.
#'
#' @examples
#' \dontrun{
#' file <- file.path(cmdstan_path(), "examples/bernoulli/bernoulli.stan")
#'
#' # create a `CmdStanModel` object, compiling the model is not required
#' mod <- cmdstan_model(file, compile = FALSE)
#'
#' mod$variables()
#'
#' }
#'
variables <- function() {
  # This method is gated on CmdStan 2.27 or newer. Compare the release
  # numerically: a plain string comparison is lexicographic and would, for
  # example, rank "2.100" below "2.27".
  version <- cmdstan_version()
  # Keep only the leading dotted-number part so that a suffix such as "-rc1"
  # does not make numeric_version() fail.
  release <- regmatches(version, regexpr("^[0-9]+(\\.[0-9]+)*", version))
  if (length(release) == 1 && numeric_version(release) < "2.27") {
    stop("$variables() is only supported for CmdStan 2.27 or newer.", call. = FALSE)
  }
  # Extract the variables only once per object; later calls return the
  # cached list stored in the private `variables_` field.
  if (is.null(private$variables_)) {
    private$variables_ <- model_variables(self$stan_file())
  }
  private$variables_
}
CmdStanModel$set("public", name = "variables", value = variables)

#' Check syntax of a Stan program
#'
#' @name model-method-check_syntax
Expand Down Expand Up @@ -1387,4 +1430,27 @@ include_paths_stanc3_args <- function(include_paths = NULL) {
stancflags <- paste0(stancflags, include_paths_flag, include_paths)
}
stancflags
}

# Extract the input and output variables of a Stan program.
#
# Runs stanc with `--info` on `stan_file`, redirecting its standard output to
# a temporary JSON file, and reads that file back as a list. The entries are
# then renamed to valid R names (`inputs` -> `data`,
# `transformed parameters` -> `transformed_parameters`,
# `generated quantities` -> `generated_quantities`) and the `functions` and
# `distributions` entries are dropped.
#
# @param stan_file Path to the Stan program passed to stanc.
# @return A list with the remaining entries of the stanc `--info` output.
#   An error is raised if stanc exits with a non-zero status
#   (`error_on_status = TRUE`).
model_variables <- function(stan_file) {
  out_file <- tempfile(fileext = ".json")
  # Remove the temporary file on exit, including when stanc or the JSON
  # parsing fails.
  on.exit(unlink(out_file), add = TRUE)
  processx::run(
    command = stanc_cmd(),
    args = c(stan_file, "--info"),
    wd = cmdstan_path(),
    echo = FALSE,
    echo_cmd = FALSE,
    stdout = out_file,
    error_on_status = TRUE
  )
  variables <- jsonlite::read_json(out_file, na = "null")
  # `[[` is used throughout for exact name matching (`$` partially matches
  # names when reading from a list).
  variables[["data"]] <- variables[["inputs"]]
  variables[["inputs"]] <- NULL
  variables[["transformed_parameters"]] <- variables[["transformed parameters"]]
  variables[["transformed parameters"]] <- NULL
  variables[["generated_quantities"]] <- variables[["generated quantities"]]
  variables[["generated quantities"]] <- NULL
  variables[["functions"]] <- NULL
  variables[["distributions"]] <- NULL
  variables
}
481 changes: 291 additions & 190 deletions docs/articles/cmdstanr-internals.html

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Pandoc 2.9 puts attributes on both the section header and its wrapping
// div. Strip them from the header so the output matches Pandoc < 2.8.
document.addEventListener('DOMContentLoaded', function() {
  var firstChildren = document.querySelectorAll("div.section[class*='level'] > :first-child");
  Array.prototype.forEach.call(firstChildren, function(node) {
    // Only h1-h6 headers are touched; any other first child is left alone.
    if (!/^h[1-6]$/i.test(node.tagName)) {
      return;
    }
    var attrs = node.attributes;
    // The attribute map shrinks as entries are removed, so walk it from the
    // end until nothing is left.
    for (var k = attrs.length - 1; k >= 0; k--) {
      node.removeAttribute(attrs[k].name);
    }
  });
});
4 changes: 2 additions & 2 deletions docs/articles/index.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/model-method-check_syntax.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/model-method-compile.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/model-method-diagnose.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/model-method-generate-quantities.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/model-method-optimize.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/model-method-sample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/model-method-sample_mpi.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

49 changes: 49 additions & 0 deletions man/model-method-variables.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/model-method-variational.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 8 additions & 8 deletions tests/testthat/test-csv.R
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ test_that("read_cmdstan_csv() errors for files from different methods", {
)
})

test_that("stan_variables and stan_variable_dims works in read_cdmstan_csv()", {
test_that("stan_variables and stan_variable_sizes works in read_cdmstan_csv()", {
skip_on_cran()
bern_opt <- read_cmdstan_csv(fit_bernoulli_optimize$output_files())
bern_vi <- read_cmdstan_csv(fit_bernoulli_variational$output_files())
Expand All @@ -462,15 +462,15 @@ test_that("stan_variables and stan_variable_dims works in read_cdmstan_csv()", {

expect_equal(gq$metadata$stan_variables, c("y_rep","sum_y"))

expect_equal(bern_opt$metadata$stan_variable_dims, list(lp__ = 1, theta = 1))
expect_equal(bern_vi$metadata$stan_variable_dims, list(lp__ = 1, lp_approx__ = 1, theta = 1))
expect_equal(bern_samp$metadata$stan_variable_dims, list(lp__ = 1, theta = 1))
expect_equal(bern_opt$metadata$stan_variable_sizes, list(lp__ = 1, theta = 1))
expect_equal(bern_vi$metadata$stan_variable_sizes, list(lp__ = 1, lp_approx__ = 1, theta = 1))
expect_equal(bern_samp$metadata$stan_variable_sizes, list(lp__ = 1, theta = 1))

expect_equal(log_opt$metadata$stan_variable_dims, list(lp__ = 1, alpha = 1, beta = 3))
expect_equal(log_vi$metadata$stan_variable_dims, list(lp__ = 1, lp_approx__ = 1, alpha = 1, beta = 3))
expect_equal(log_samp$metadata$stan_variable_dims, list(lp__ = 1, alpha = 1, beta = 3))
expect_equal(log_opt$metadata$stan_variable_sizes, list(lp__ = 1, alpha = 1, beta = 3))
expect_equal(log_vi$metadata$stan_variable_sizes, list(lp__ = 1, lp_approx__ = 1, alpha = 1, beta = 3))
expect_equal(log_samp$metadata$stan_variable_sizes, list(lp__ = 1, alpha = 1, beta = 3))

expect_equal(gq$metadata$stan_variable_dims, list(y_rep = 10, sum_y = 1))
expect_equal(gq$metadata$stan_variable_sizes, list(y_rep = 10, sum_y = 1))
})

test_that("returning time works for read_cmdstan_csv", {
Expand Down
77 changes: 77 additions & 0 deletions tests/testthat/test-model-variables.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Tests for the `$variables()` method of CmdStanModel objects.
context("model-variables")

# The CmdStan path is only configured when not running on CRAN.
if (not_on_cran()) {
  set_cmdstan_path()
}

test_that("$variables() work correctly with example models", {
  skip_on_cran()

  # bernoulli: two data variables, one parameter, and empty (but list-typed)
  # transformed parameters / generated quantities blocks.
  bern <- testing_model("bernoulli")$variables()
  expect_equal(names(bern$data), c("N", "y"))
  expect_equal(names(bern$parameters), c("theta"))
  expect_equal(bern$data$N$type, "int")
  expect_equal(bern$data$N$dimensions, 0)
  expect_equal(bern$data$y$type, "int")
  expect_equal(bern$data$y$dimensions, 1)
  expect_equal(bern$parameters$theta$type, "real")
  expect_equal(bern$parameters$theta$dimensions, 0)
  expect_equal(length(bern$transformed_parameters), 0)
  expect_equal(length(bern$generated_quantities), 0)
  expect_true(is.list(bern$transformed_parameters))
  expect_true(is.list(bern$generated_quantities))

  # bernoulli_log_lik: additionally has a generated quantity.
  bern_ll <- testing_model("bernoulli_log_lik")$variables()
  expect_equal(names(bern_ll$data), c("N", "y"))
  expect_equal(names(bern_ll$parameters), c("theta"))
  expect_equal(names(bern_ll$generated_quantities), c("log_lik"))
  expect_equal(bern_ll$generated_quantities$log_lik$type, "real")
  expect_equal(bern_ll$generated_quantities$log_lik$dimensions, 1)

  # logistic: scalar, one-dimensional and two-dimensional data variables.
  logistic <- testing_model("logistic")$variables()
  expect_equal(names(logistic$data), c("N", "K", "y", "X"))
  expect_equal(names(logistic$parameters), c("alpha", "beta"))
  expect_equal(logistic$data$N$type, "int")
  expect_equal(logistic$data$N$dimensions, 0)
  expect_equal(logistic$data$K$type, "int")
  expect_equal(logistic$data$K$dimensions, 0)
  expect_equal(logistic$data$y$type, "int")
  expect_equal(logistic$data$y$dimensions, 1)
  expect_equal(logistic$data$X$type, "real")
  expect_equal(logistic$data$X$dimensions, 2)
  expect_equal(logistic$parameters$alpha$type, "real")
  expect_equal(logistic$parameters$alpha$dimensions, 0)
  expect_equal(logistic$parameters$beta$type, "real")
  expect_equal(logistic$parameters$beta$dimensions, 1)
})

# Uses a distinct description from the example-model test so that a failure
# report identifies which of the two tests broke.
test_that("$variables() reports types and dimensions of arrays and containers", {
  skip_on_cran()
  # Arrays of scalars, an array of vectors and an array of matrices; the
  # expected `dimensions` below add the array dimensions to those of the
  # element type (e.g. array[1,2,3,4] vector[4] -> 5).
  code <- "
data {
array[1,2,3,4,5,6,7,8] int y;
array[1,2,3,4] vector[4] x;
}
parameters {
real z;
}
transformed parameters {
array[1,2,3] real p;
array[2] matrix[2,3] pp;
}
"
  stan_file <- write_stan_file(code)
  mod <- cmdstan_model(stan_file)
  vars <- mod$variables()

  expect_equal(names(vars$data), c("y", "x"))
  expect_equal(names(vars$parameters), c("z"))
  expect_equal(names(vars$transformed_parameters), c("p", "pp"))

  expect_equal(vars$data$y$type, "int")
  expect_equal(vars$data$y$dimensions, 8)
  expect_equal(vars$data$x$type, "real")
  expect_equal(vars$data$x$dimensions, 5)

  expect_equal(vars$parameters$z$type, "real")
  expect_equal(vars$parameters$z$dimensions, 0)

  expect_equal(vars$transformed_parameters$p$type, "real")
  expect_equal(vars$transformed_parameters$p$dimensions, 3)
  expect_equal(vars$transformed_parameters$pp$type, "real")
  expect_equal(vars$transformed_parameters$pp$dimensions, 3)
})
Loading

0 comments on commit 58f0980

Please sign in to comment.