From e2ef545878b89cc1d826cc6424c99016721a21e0 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Mon, 17 Oct 2022 10:35:57 +0300 Subject: [PATCH 1/5] Simplify construction/exposing of parameter skeleton --- R/fit.R | 85 +++++++++++++++++------------ R/utils.R | 37 +++++++++---- inst/include/model_methods.cpp | 2 +- tests/testthat/test-model-methods.R | 9 +++ 4 files changed, 86 insertions(+), 47 deletions(-) diff --git a/R/fit.R b/R/fit.R index 713106254..8883f0414 100644 --- a/R/fit.R +++ b/R/fit.R @@ -8,21 +8,20 @@ CmdStanFit <- R6::R6Class( classname = "CmdStanFit", public = list( - runset = NULL, functions = NULL, initialize = function(runset) { checkmate::assert_r6(runset, classes = "CmdStanRun") - self$runset <- runset + private$runset_ <- runset private$model_methods_env_ <- runset$model_methods_env() self$functions <- runset$standalone_env() if (!is.null(private$model_methods_env_$model_ptr)) { - initialize_model_pointer(private$model_methods_env_, self$data_file(), 0) + initialize_model_env(private$model_methods_env_, self$data_file(), 0) } invisible(self) }, num_procs = function() { - self$runset$num_procs() + private$runset_$num_procs() }, print = function(variables = NULL, ..., digits = 2, max_rows = getOption("cmdstanr_max_rows", 10)) { if (is.null(private$draws_) && @@ -68,14 +67,16 @@ CmdStanFit <- R6::R6Class( expose_functions = function(global = FALSE, verbose = FALSE) { expose_functions(self$functions, global, verbose) invisible(NULL) - } + }, + variables = function() { private$runset_$args$model_variables } ), private = list( draws_ = NULL, metadata_ = NULL, init_ = NULL, profiles_ = NULL, - model_methods_env_ = NULL + model_methods_env_ = NULL, + runset_ = NULL ) ) @@ -318,7 +319,11 @@ init_model_methods <- function(seed = 0, verbose = FALSE, hessian = FALSE) { if (is.null(private$model_methods_env_$model_ptr)) { expose_model_methods(private$model_methods_env_, verbose, hessian) } - initialize_model_pointer(private$model_methods_env_, self$data_file(), seed) + initialize_model_env(private$model_methods_env_, self$data_file(), seed) + private$runset_$args$model_variables <- + add_param_model_sizes(private$runset_$args$model_variables, + private$model_methods_env_$param_sizes_) + invisible(NULL) } CmdStanFit$set("public", name = "init_model_methods", value = init_model_methods) @@ -435,7 +440,7 @@ unconstrain_pars <- function(pars) { stop("The method has not been compiled, please call `init_model_methods()` first", call. = FALSE) } - model_par_names <- names(self$runset$args$model_variables$parameters) + model_par_names <- names(private$runset_$args$model_variables$parameters) prov_par_names <- names(pars) model_pars_not_prov <- which(!(model_par_names %in% prov_par_names)) @@ -447,7 +452,7 @@ unconstrain_pars <- function(pars) { # Ignore extraneous parameters model_pars_only <- pars[model_par_names] - stan_pars <- process_init_list(list(pars), num_procs = 1, self$runset$args$model_variables) + stan_pars <- process_init_list(list(pars), num_procs = 1, private$runset_$args$model_variables) private$model_methods_env_$unconstrain_pars(private$model_methods_env_$model_ptr_, stan_pars) } CmdStanFit$set("public", name = "unconstrain_pars", value = unconstrain_pars) @@ -471,24 +476,32 @@ CmdStanFit$set("public", name = "unconstrain_pars", value = unconstrain_pars) #' fit_mcmc$constrain_pars(upars = c(0.5, 1.2, 1.1, 2.2, 1.1)) #' } #' -constrain_pars <- function(upars, transformed_parameters = TRUE, generated_quantities = TRUE) { +constrain_pars <- function(upars = NULL, transformed_parameters = TRUE, + generated_quantities = TRUE, + skeleton_only = FALSE) { if (is.null(private$model_methods_env_$model_ptr)) { stop("The method has not been compiled, please call `init_model_methods()` first", call. = FALSE) } + + skeleton <- create_skeleton(private$runset_$args$model_variables, + transformed_parameters, + generated_quantities) + + if (skeleton_only) { + return(skeleton) + } + if (length(upars) != private$model_methods_env_$num_upars_) { stop("Model has ", private$model_methods_env_$num_upars_, " unconstrained parameter(s), but ", length(upars), " were provided!", call. = FALSE) } + cpars <- private$model_methods_env_$constrain_pars( private$model_methods_env_$model_ptr_, private$model_methods_env_$model_rng_, upars, transformed_parameters, generated_quantities) - skeleton <- create_skeleton(private$model_methods_env_$param_metadata_, - self$runset$args$model_variables, - transformed_parameters, - generated_quantities) utils::relist(cpars, skeleton) } CmdStanFit$set("public", name = "constrain_pars", value = constrain_pars) @@ -652,13 +665,13 @@ CmdStanFit$set("public", name = "summary", value = summary) #' } #' cmdstan_summary <- function(flags = NULL) { - self$runset$run_cmdstan_tool("stansummary", flags = flags) + private$runset_$run_cmdstan_tool("stansummary", flags = flags) } CmdStanFit$set("public", name = "cmdstan_summary", value = cmdstan_summary) #' @rdname fit-method-cmdstan_summary cmdstan_diagnose <- function() { - self$runset$run_cmdstan_tool("diagnose") + private$runset_$run_cmdstan_tool("diagnose") } CmdStanFit$set("public", name = "cmdstan_diagnose", value = cmdstan_diagnose) @@ -734,7 +747,7 @@ save_output_files <- function(dir = ".", basename = NULL, timestamp = TRUE, random = TRUE) { - self$runset$save_output_files(dir, basename, timestamp, random) + private$runset_$save_output_files(dir, basename, timestamp, random) } CmdStanFit$set("public", name = "save_output_files", value = save_output_files) @@ -743,7 +756,7 @@ save_latent_dynamics_files <- function(dir = ".", basename = NULL, timestamp = TRUE, random = TRUE) { - self$runset$save_latent_dynamics_files(dir, basename, timestamp, random) + private$runset_$save_latent_dynamics_files(dir, basename, timestamp, random) } CmdStanFit$set("public", name = "save_latent_dynamics_files", value = save_latent_dynamics_files) @@ -752,7 +765,7 @@ save_profile_files <- function(dir = ".", basename = NULL, timestamp = TRUE, random = TRUE) { - self$runset$save_profile_files(dir, basename, timestamp, random) + private$runset_$save_profile_files(dir, basename, timestamp, random) } CmdStanFit$set("public", name = "save_profile_files", value = save_profile_files) @@ -761,7 +774,7 @@ save_data_file <- function(dir = ".", basename = NULL, timestamp = TRUE, random = TRUE) { - self$runset$save_data_file(dir, basename, timestamp, random) + private$runset_$save_data_file(dir, basename, timestamp, random) } CmdStanFit$set("public", name = "save_data_file", value = save_data_file) @@ -769,25 +782,25 @@ CmdStanFit$set("public", name = "save_data_file", value = save_data_file) #' @param include_failed (logical) Should CmdStan runs that failed also be #' included? The default is `FALSE.` output_files <- function(include_failed = FALSE) { - self$runset$output_files(include_failed) + private$runset_$output_files(include_failed) } CmdStanFit$set("public", name = "output_files", value = output_files) #' @rdname fit-method-save_output_files profile_files <- function(include_failed = FALSE) { - self$runset$profile_files(include_failed) + private$runset_$profile_files(include_failed) } CmdStanFit$set("public", name = "profile_files", value = profile_files) #' @rdname fit-method-save_output_files latent_dynamics_files <- function(include_failed = FALSE) { - self$runset$latent_dynamics_files(include_failed) + private$runset_$latent_dynamics_files(include_failed) } CmdStanFit$set("public", name = "latent_dynamics_files", value = latent_dynamics_files) #' @rdname fit-method-save_output_files data_file <- function() { - self$runset$data_file() + private$runset_$data_file() } CmdStanFit$set("public", name = "data_file", value = data_file) @@ -823,7 +836,7 @@ CmdStanFit$set("public", name = "data_file", value = data_file) #' } #' time <- function() { - self$runset$time() + private$runset_$time() } CmdStanFit$set("public", name = "time", value = time) @@ -861,7 +874,7 @@ CmdStanFit$set("public", name = "time", value = time) output <- function(id = NULL) { # MCMC has separate implementation but doc is shared # Non-MCMC fit is obtained with one process only so id is ignored - cat(paste(self$runset$procs$proc_output(1), collapse = "\n")) + cat(paste(private$runset_$procs$proc_output(1), collapse = "\n")) } CmdStanFit$set("public", name = "output", value = output) @@ -921,7 +934,7 @@ CmdStanFit$set("public", name = "metadata", value = metadata) #' } #' return_codes <- function() { - self$runset$procs$return_codes() + private$runset_$procs$return_codes() } CmdStanFit$set("public", name = "return_codes", value = return_codes) @@ -1004,7 +1017,7 @@ CmdStanFit$set("public", name = "profiles", value = profiles) #' } #' code <- function() { - stan_code <- self$runset$stan_code() + stan_code <- private$runset_$stan_code() if (is.null(stan_code)) { warning("'$code()' will return NULL because the 'CmdStanModel' was not created with a Stan file.", call. = FALSE) } @@ -1082,7 +1095,7 @@ CmdStanMCMC <- R6::R6Class( if (runset$args$method_args$fixed_param) { private$read_csv_(variables = "", sampler_diagnostics = "") } else { - diagnostics <- self$runset$args$method_args$diagnostics + diagnostics <- private$runset_$args$method_args$diagnostics private$read_csv_( variables = "", sampler_diagnostics = convert_hmc_diagnostic_names(diagnostics) @@ -1094,9 +1107,9 @@ CmdStanMCMC <- R6::R6Class( # override the CmdStanFit output method output = function(id = NULL) { if (is.null(id)) { - self$runset$procs$proc_output() + private$runset_$procs$proc_output() } else { - cat(paste(self$runset$procs$proc_output(id), collapse = "\n")) + cat(paste(private$runset_$procs$proc_output(id), collapse = "\n")) } }, @@ -1697,7 +1710,7 @@ CmdStanGQ <- R6::R6Class( inherit = CmdStanFit, public = list( fitted_params_files = function() { - self$runset$args$method_args$fitted_params + private$runset_$args$method_args$fitted_params }, num_chains = function() { super$num_procs() @@ -1736,9 +1749,9 @@ CmdStanGQ <- R6::R6Class( # override CmdStanFit output method output = function(id = NULL) { if (is.null(id)) { - self$runset$procs$proc_output() + private$runset_$procs$proc_output() } else { - cat(paste(self$runset$procs$proc_output(id), collapse = "\n")) + cat(paste(private$runset_$procs$proc_output(id), collapse = "\n")) } } ), @@ -1806,8 +1819,8 @@ CmdStanDiagnose <- R6::R6Class( runset = NULL, initialize = function(runset) { checkmate::assert_r6(runset, classes = "CmdStanRun") - self$runset <- runset - csv_data <- read_cmdstan_csv(self$runset$output_files()) + private$runset_ <- runset + csv_data <- read_cmdstan_csv(private$runset_$output_files()) private$metadata_ <- csv_data$metadata private$gradients_ <- csv_data$gradients private$lp_ <- csv_data$lp diff --git a/R/utils.R b/R/utils.R index b1057e07b..32414d22d 100644 --- a/R/utils.R +++ b/R/utils.R @@ -597,29 +597,46 @@ expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) { invisible(NULL) } -initialize_model_pointer <- function(env, data, seed = 0) { +initialize_model_env <- function(env, data, seed = 0) { ptr_and_rng <- env$model_ptr(data, seed) env$model_ptr_ <- ptr_and_rng$model_ptr env$model_rng_ <- ptr_and_rng$base_rng env$num_upars_ <- env$get_num_upars(env$model_ptr_) - env$param_metadata_ <- env$get_param_metadata(env$model_ptr_) + env$param_sizes_ <- env$get_param_sizes(env$model_ptr_) invisible(NULL) } -create_skeleton <- function(param_metadata, model_variables, +add_param_model_sizes <- function(model_variables, param_sizes) { + lapply(model_variables, function(block) { + item_names <- names(block) + item_in_metadata <- any(item_names %in% names(param_sizes)) + if (item_in_metadata) { + for (nm in item_names) { + block[[nm]]$size <- param_sizes[[nm]] + } + } + block + }) +} + +create_skeleton <- function(model_variables, transformed_parameters, generated_quantities) { - target_params <- names(model_variables$parameters) + blocks <- "parameters" if (transformed_parameters) { - target_params <- c(target_params, - names(model_variables$transformed_parameters)) + blocks <- c(blocks, "transformed_parameters") } if (generated_quantities) { - target_params <- c(target_params, - names(model_variables$generated_quantities)) + blocks <- c(blocks, "generated_quantities") } - lapply(param_metadata[target_params], function(par_dims) { - array(0, dim = ifelse(length(par_dims) == 0, 1, par_dims)) + + nested_skeletion <- lapply(model_variables[blocks], function(block) { + lapply(block, function(item) { + array(0, dim = ifelse(length(item$size) == 0, 1, item$size)) + } + ) }) + skeleton <- unlist(nested_skeletion, recursive = FALSE) + stats::setNames(skeleton, gsub(paste0(blocks, ".", collapse = "|"), "", names(skeleton))) } get_standalone_hpp <- function(stan_file, stancflags) { diff --git a/inst/include/model_methods.cpp b/inst/include/model_methods.cpp index bff21fb64..c994d98e2 100644 --- a/inst/include/model_methods.cpp +++ b/inst/include/model_methods.cpp @@ -70,7 +70,7 @@ size_t get_num_upars(SEXP ext_model_ptr) { } // [[Rcpp::export]] -Rcpp::List get_param_metadata(SEXP ext_model_ptr) { +Rcpp::List get_param_sizes(SEXP ext_model_ptr) { Rcpp::XPtr ptr(ext_model_ptr); std::vector param_names; std::vector > param_dims; diff --git a/tests/testthat/test-model-methods.R b/tests/testthat/test-model-methods.R index da27134e7..940f983b3 100644 --- a/tests/testthat/test-model-methods.R +++ b/tests/testthat/test-model-methods.R @@ -66,6 +66,14 @@ test_that("Methods return correct values", { expect_equal(fit$constrain_pars(c(0.1), generated_quantities = FALSE), list(theta = 0.52497918747894001257)) + skeleton <- list( + theta = array(0, dim = 1), + log_lik = array(0, dim = data_list$N) + ) + + expect_equal(fit$constrain_pars(skeleton_only = TRUE), + skeleton) + upars <- fit$unconstrain_pars(cpars) expect_equal(upars, c(0.1)) }) @@ -108,6 +116,7 @@ test_that("methods error for incorrect inputs", { }) test_that("Methods error with already-compiled model", { + mod1 <- testing_model("bernoulli") mod <- testing_model("bernoulli") data_list <- testing_data("bernoulli") fit <- mod$sample(data = data_list, chains = 1) From 0882fd6d32f47ed492d5278993626d0f68799bb3 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Mon, 17 Oct 2022 10:42:59 +0300 Subject: [PATCH 2/5] Missed commit --- R/utils.R | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/R/utils.R b/R/utils.R index 32414d22d..6418600ce 100644 --- a/R/utils.R +++ b/R/utils.R @@ -619,8 +619,8 @@ add_param_model_sizes <- function(model_variables, param_sizes) { }) } -create_skeleton <- function(model_variables, - transformed_parameters, generated_quantities) { +create_skeleton <- function(model_variables, transformed_parameters, + generated_quantities) { blocks <- "parameters" if (transformed_parameters) { blocks <- c(blocks, "transformed_parameters") @@ -632,8 +632,7 @@ create_skeleton <- function(model_variables, nested_skeletion <- lapply(model_variables[blocks], function(block) { lapply(block, function(item) { array(0, dim = ifelse(length(item$size) == 0, 1, item$size)) - } - ) + }) }) skeleton <- unlist(nested_skeletion, recursive = FALSE) stats::setNames(skeleton, gsub(paste0(blocks, ".", collapse = "|"), "", names(skeleton))) From b95f857cde5cb16e53864de3be96bd3069d0f7d1 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Mon, 17 Oct 2022 10:49:21 +0300 Subject: [PATCH 3/5] Add doc --- R/fit.R | 2 ++ man/fit-method-constrain_pars.Rd | 8 ++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/R/fit.R b/R/fit.R index 8883f0414..68bb465c6 100644 --- a/R/fit.R +++ b/R/fit.R @@ -469,6 +469,8 @@ CmdStanFit$set("public", name = "unconstrain_pars", value = unconstrain_pars) #' implied by newly-constrained parameters (defaults to TRUE) #' @param generated_quantities (boolean) Whether to return generated quantities #' implied by newly-constrained parameters (defaults to TRUE) +#' @param skeleton_only (boolean) Whether to return only the "skeleton" needed by the +#' utils::relist function (defaults to FALSE) #' #' @examples #' \dontrun{ diff --git a/man/fit-method-constrain_pars.Rd b/man/fit-method-constrain_pars.Rd index cc4e8f821..0276d7d4a 100644 --- a/man/fit-method-constrain_pars.Rd +++ b/man/fit-method-constrain_pars.Rd @@ -6,9 +6,10 @@ \title{Transform a set of unconstrained parameter values to the constrained scale} \usage{ constrain_pars( - upars, + upars = NULL, transformed_parameters = TRUE, - generated_quantities = TRUE + generated_quantities = TRUE, + skeleton_only = FALSE ) } \arguments{ @@ -19,6 +20,9 @@ implied by newly-constrained parameters (defaults to TRUE)} \item{generated_quantities}{(boolean) Whether to return generated quantities implied by newly-constrained parameters (defaults to TRUE)} + +\item{skeleton_only}{(boolean) Whether to return only the "skeleton" needed by the +utils::relist function (defaults to FALSE)} } \description{ The \verb{$constrain_pars()} method transforms input parameters to From 7bc26d2b4d915564d2a4df10f4663bf268b67cc9 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Mon, 17 Oct 2022 11:42:30 +0300 Subject: [PATCH 4/5] Undo unnecessary runset change --- R/fit.R | 64 ++++++++++++++++++++++++++++----------------------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/R/fit.R b/R/fit.R index 68bb465c6..1cbd84ae9 100644 --- a/R/fit.R +++ b/R/fit.R @@ -9,9 +9,10 @@ CmdStanFit <- R6::R6Class( classname = "CmdStanFit", public = list( functions = NULL, + runset = NULL, initialize = function(runset) { checkmate::assert_r6(runset, classes = "CmdStanRun") - private$runset_ <- runset + self$runset <- runset private$model_methods_env_ <- runset$model_methods_env() self$functions <- runset$standalone_env() @@ -21,7 +22,7 @@ CmdStanFit <- R6::R6Class( invisible(self) }, num_procs = function() { - private$runset_$num_procs() + self$runset$num_procs() }, print = function(variables = NULL, ..., digits = 2, max_rows = getOption("cmdstanr_max_rows", 10)) { if (is.null(private$draws_) && @@ -68,15 +69,14 @@ CmdStanFit <- R6::R6Class( expose_functions(self$functions, global, verbose) invisible(NULL) }, - variables = function() { private$runset_$args$model_variables } + variables = function() { self$runset$args$model_variables } ), private = list( draws_ = NULL, metadata_ = NULL, init_ = NULL, profiles_ = NULL, - model_methods_env_ = NULL, - runset_ = NULL + model_methods_env_ = NULL ) ) @@ -320,8 +320,8 @@ init_model_methods <- function(seed = 0, verbose = FALSE, hessian = FALSE) { expose_model_methods(private$model_methods_env_, verbose, hessian) } initialize_model_env(private$model_methods_env_, self$data_file(), seed) - private$runset_$args$model_variables <- - add_param_model_sizes(private$runset_$args$model_variables, + self$runset$args$model_variables <- + add_param_model_sizes(self$runset$args$model_variables, private$model_methods_env_$param_sizes_) invisible(NULL) @@ -440,7 +440,7 @@ unconstrain_pars <- function(pars) { stop("The method has not been compiled, please call `init_model_methods()` first", call. = FALSE) } - model_par_names <- names(private$runset_$args$model_variables$parameters) + model_par_names <- names(self$runset$args$model_variables$parameters) prov_par_names <- names(pars) model_pars_not_prov <- which(!(model_par_names %in% prov_par_names)) @@ -452,7 +452,7 @@ unconstrain_pars <- function(pars) { # Ignore extraneous parameters model_pars_only <- pars[model_par_names] - stan_pars <- process_init_list(list(pars), num_procs = 1, private$runset_$args$model_variables) + stan_pars <- process_init_list(list(pars), num_procs = 1, self$runset$args$model_variables) private$model_methods_env_$unconstrain_pars(private$model_methods_env_$model_ptr_, stan_pars) } CmdStanFit$set("public", name = "unconstrain_pars", value = unconstrain_pars) @@ -486,7 +486,7 @@ constrain_pars <- function(upars = NULL, transformed_parameters = TRUE, call. = FALSE) } - skeleton <- create_skeleton(private$runset_$args$model_variables, + skeleton <- create_skeleton(self$runset$args$model_variables, transformed_parameters, generated_quantities) @@ -667,13 +667,13 @@ CmdStanFit$set("public", name = "summary", value = summary) #' } #' cmdstan_summary <- function(flags = NULL) { - private$runset_$run_cmdstan_tool("stansummary", flags = flags) + self$runset$run_cmdstan_tool("stansummary", flags = flags) } CmdStanFit$set("public", name = "cmdstan_summary", value = cmdstan_summary) #' @rdname fit-method-cmdstan_summary cmdstan_diagnose <- function() { - private$runset_$run_cmdstan_tool("diagnose") + self$runset$run_cmdstan_tool("diagnose") } CmdStanFit$set("public", name = "cmdstan_diagnose", value = cmdstan_diagnose) @@ -749,7 +749,7 @@ save_output_files <- function(dir = ".", basename = NULL, timestamp = TRUE, random = TRUE) { - private$runset_$save_output_files(dir, basename, timestamp, random) + self$runset$save_output_files(dir, basename, timestamp, random) } CmdStanFit$set("public", name = "save_output_files", value = save_output_files) @@ -758,7 +758,7 @@ save_latent_dynamics_files <- function(dir = ".", basename = NULL, timestamp = TRUE, random = TRUE) { - private$runset_$save_latent_dynamics_files(dir, basename, timestamp, random) + self$runset$save_latent_dynamics_files(dir, basename, timestamp, random) } CmdStanFit$set("public", name = "save_latent_dynamics_files", value = save_latent_dynamics_files) @@ -767,7 +767,7 @@ save_profile_files <- function(dir = ".", basename = NULL, timestamp = TRUE, random = TRUE) { - private$runset_$save_profile_files(dir, basename, timestamp, random) + self$runset$save_profile_files(dir, basename, timestamp, random) } CmdStanFit$set("public", name = "save_profile_files", value = save_profile_files) @@ -776,7 +776,7 @@ save_data_file <- function(dir = ".", basename = NULL, timestamp = TRUE, random = TRUE) { - private$runset_$save_data_file(dir, basename, timestamp, random) + self$runset$save_data_file(dir, basename, timestamp, random) } CmdStanFit$set("public", name = "save_data_file", value = save_data_file) @@ -784,25 +784,25 @@ CmdStanFit$set("public", name = "save_data_file", value = save_data_file) #' @param include_failed (logical) Should CmdStan runs that failed also be #' included? The default is `FALSE.` output_files <- function(include_failed = FALSE) { - private$runset_$output_files(include_failed) + self$runset$output_files(include_failed) } CmdStanFit$set("public", name = "output_files", value = output_files) #' @rdname fit-method-save_output_files profile_files <- function(include_failed = FALSE) { - private$runset_$profile_files(include_failed) + self$runset$profile_files(include_failed) } CmdStanFit$set("public", name = "profile_files", value = profile_files) #' @rdname fit-method-save_output_files latent_dynamics_files <- function(include_failed = FALSE) { - private$runset_$latent_dynamics_files(include_failed) + self$runset$latent_dynamics_files(include_failed) } CmdStanFit$set("public", name = "latent_dynamics_files", value = latent_dynamics_files) #' @rdname fit-method-save_output_files data_file <- function() { - private$runset_$data_file() + self$runset$data_file() } CmdStanFit$set("public", name = "data_file", value = data_file) @@ -838,7 +838,7 @@ CmdStanFit$set("public", name = "data_file", value = data_file) #' } #' time <- function() { - private$runset_$time() + self$runset$time() } CmdStanFit$set("public", name = "time", value = time) @@ -876,7 +876,7 @@ CmdStanFit$set("public", name = "time", value = time) output <- function(id = NULL) { # MCMC has separate implementation but doc is shared # Non-MCMC fit is obtained with one process only so id is ignored - cat(paste(private$runset_$procs$proc_output(1), collapse = "\n")) + cat(paste(self$runset$procs$proc_output(1), collapse = "\n")) } CmdStanFit$set("public", name = "output", value = output) @@ -936,7 +936,7 @@ CmdStanFit$set("public", name = "metadata", value = metadata) #' } #' return_codes <- function() { - private$runset_$procs$return_codes() + self$runset$procs$return_codes() } CmdStanFit$set("public", name = "return_codes", value = return_codes) @@ -1019,7 +1019,7 @@ CmdStanFit$set("public", name = "profiles", value = profiles) #' } #' code <- function() { - stan_code <- private$runset_$stan_code() + stan_code <- self$runset$stan_code() if (is.null(stan_code)) { warning("'$code()' will return NULL because the 'CmdStanModel' was not created with a Stan file.", call. = FALSE) } @@ -1097,7 +1097,7 @@ CmdStanMCMC <- R6::R6Class( if (runset$args$method_args$fixed_param) { private$read_csv_(variables = "", sampler_diagnostics = "") } else { - diagnostics <- private$runset_$args$method_args$diagnostics + diagnostics <- self$runset$args$method_args$diagnostics private$read_csv_( variables = "", sampler_diagnostics = convert_hmc_diagnostic_names(diagnostics) @@ -1109,9 +1109,9 @@ CmdStanMCMC <- R6::R6Class( # override the CmdStanFit output method output = function(id = NULL) { if (is.null(id)) { - private$runset_$procs$proc_output() + self$runset$procs$proc_output() } else { - cat(paste(private$runset_$procs$proc_output(id), collapse = "\n")) + cat(paste(self$runset$procs$proc_output(id), collapse = "\n")) } }, @@ -1712,7 +1712,7 @@ CmdStanGQ <- R6::R6Class( inherit = CmdStanFit, public = list( fitted_params_files = function() { - private$runset_$args$method_args$fitted_params + self$runset$args$method_args$fitted_params }, num_chains = function() { super$num_procs() @@ -1751,9 +1751,9 @@ CmdStanGQ <- R6::R6Class( # override CmdStanFit output method output = function(id = NULL) { if (is.null(id)) { - private$runset_$procs$proc_output() + self$runset$procs$proc_output() } else { - cat(paste(private$runset_$procs$proc_output(id), collapse = "\n")) + cat(paste(self$runset$procs$proc_output(id), collapse = "\n")) } } ), @@ -1821,8 +1821,8 @@ CmdStanDiagnose <- R6::R6Class( runset = NULL, initialize = function(runset) { checkmate::assert_r6(runset, classes = "CmdStanRun") - private$runset_ <- runset - csv_data <- read_cmdstan_csv(private$runset_$output_files()) + self$runset <- runset + csv_data <- read_cmdstan_csv(self$runset$output_files()) private$metadata_ <- csv_data$metadata private$gradients_ <- csv_data$gradients private$lp_ <- csv_data$lp From f979f2a93351b645c57a1d85efca62a81400846e Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Thu, 20 Oct 2022 08:15:32 +0300 Subject: [PATCH 5/5] Remove redundant changes --- R/fit.R | 22 ++++++----------- R/utils.R | 38 +++++++++-------------------- inst/include/model_methods.cpp | 2 +- man/fit-method-constrain_pars.Rd | 2 +- tests/testthat/test-model-methods.R | 1 - 5 files changed, 20 insertions(+), 45 deletions(-) diff --git a/R/fit.R b/R/fit.R index 1cbd84ae9..bb56d3b04 100644 --- a/R/fit.R +++ b/R/fit.R @@ -8,8 +8,8 @@ CmdStanFit <- R6::R6Class( classname = "CmdStanFit", public = list( - functions = NULL, runset = NULL, + functions = NULL, initialize = function(runset) { checkmate::assert_r6(runset, classes = "CmdStanRun") self$runset <- runset @@ -17,7 +17,7 @@ CmdStanFit <- R6::R6Class( self$functions <- runset$standalone_env() if (!is.null(private$model_methods_env_$model_ptr)) { - initialize_model_env(private$model_methods_env_, self$data_file(), 0) + initialize_model_pointer(private$model_methods_env_, self$data_file(), 0) } invisible(self) }, @@ -68,8 +68,7 @@ CmdStanFit <- R6::R6Class( expose_functions = function(global = FALSE, verbose = FALSE) { expose_functions(self$functions, global, verbose) invisible(NULL) - }, - variables = function() { self$runset$args$model_variables } + } ), private = list( draws_ = NULL, @@ -319,11 +318,7 @@ init_model_methods <- function(seed = 0, verbose = FALSE, hessian = FALSE) { if (is.null(private$model_methods_env_$model_ptr)) { expose_model_methods(private$model_methods_env_, verbose, hessian) } - initialize_model_env(private$model_methods_env_, self$data_file(), seed) - self$runset$args$model_variables <- - add_param_model_sizes(self$runset$args$model_variables, - private$model_methods_env_$param_sizes_) - + initialize_model_pointer(private$model_methods_env_, self$data_file(), seed) invisible(NULL) } CmdStanFit$set("public", name = "init_model_methods", value = init_model_methods) @@ -478,18 +473,17 @@ CmdStanFit$set("public", name = "unconstrain_pars", value = unconstrain_pars) #' fit_mcmc$constrain_pars(upars = c(0.5, 1.2, 1.1, 2.2, 1.1)) #' } #' -constrain_pars <- function(upars = NULL, transformed_parameters = TRUE, - generated_quantities = TRUE, +constrain_pars <- function(upars, transformed_parameters = TRUE, generated_quantities = TRUE, skeleton_only = FALSE) { if (is.null(private$model_methods_env_$model_ptr)) { stop("The method has not been compiled, please call `init_model_methods()` first", call. = FALSE) } - skeleton <- create_skeleton(self$runset$args$model_variables, + skeleton <- create_skeleton(private$model_methods_env_$param_metadata_, + self$runset$args$model_variables, transformed_parameters, generated_quantities) - if (skeleton_only) { return(skeleton) } @@ -498,12 +492,10 @@ constrain_pars <- function(upars = NULL, transformed_parameters = TRUE, stop("Model has ", private$model_methods_env_$num_upars_, " unconstrained parameter(s), but ", length(upars), " were provided!", call. = FALSE) } - cpars <- private$model_methods_env_$constrain_pars( private$model_methods_env_$model_ptr_, private$model_methods_env_$model_rng_, upars, transformed_parameters, generated_quantities) - utils::relist(cpars, skeleton) } CmdStanFit$set("public", name = "constrain_pars", value = constrain_pars) diff --git a/R/utils.R b/R/utils.R index 6418600ce..b1057e07b 100644 --- a/R/utils.R +++ b/R/utils.R @@ -597,45 +597,29 @@ expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) { invisible(NULL) } -initialize_model_env <- function(env, data, seed = 0) { +initialize_model_pointer <- function(env, data, seed = 0) { ptr_and_rng <- env$model_ptr(data, seed) env$model_ptr_ <- ptr_and_rng$model_ptr env$model_rng_ <- ptr_and_rng$base_rng env$num_upars_ <- env$get_num_upars(env$model_ptr_) - env$param_sizes_ <- env$get_param_sizes(env$model_ptr_) + env$param_metadata_ <- env$get_param_metadata(env$model_ptr_) invisible(NULL) } -add_param_model_sizes <- function(model_variables, param_sizes) { - lapply(model_variables, function(block) { - item_names <- names(block) - item_in_metadata <- any(item_names %in% names(param_sizes)) - if (item_in_metadata) { - for (nm in item_names) { - block[[nm]]$size <- param_sizes[[nm]] - } - } - block - }) -} - -create_skeleton <- function(model_variables, transformed_parameters, - generated_quantities) { - blocks <- "parameters" +create_skeleton <- function(param_metadata, model_variables, + transformed_parameters, generated_quantities) { + target_params <- names(model_variables$parameters) if (transformed_parameters) { - blocks <- c(blocks, "transformed_parameters") + target_params <- c(target_params, + names(model_variables$transformed_parameters)) } if (generated_quantities) { - blocks <- c(blocks, "generated_quantities") + target_params <- c(target_params, + names(model_variables$generated_quantities)) } - - nested_skeletion <- lapply(model_variables[blocks], function(block) { - lapply(block, function(item) { - array(0, dim = ifelse(length(item$size) == 0, 1, item$size)) - }) + lapply(param_metadata[target_params], function(par_dims) { + array(0, dim = ifelse(length(par_dims) == 0, 1, par_dims)) }) - skeleton <- unlist(nested_skeletion, recursive = FALSE) - stats::setNames(skeleton, gsub(paste0(blocks, ".", collapse = "|"), "", names(skeleton))) } get_standalone_hpp <- function(stan_file, stancflags) { diff --git a/inst/include/model_methods.cpp b/inst/include/model_methods.cpp index c994d98e2..bff21fb64 100644 --- a/inst/include/model_methods.cpp +++ b/inst/include/model_methods.cpp @@ -70,7 +70,7 @@ size_t get_num_upars(SEXP ext_model_ptr) { } // [[Rcpp::export]] -Rcpp::List get_param_sizes(SEXP ext_model_ptr) { +Rcpp::List get_param_metadata(SEXP ext_model_ptr) { Rcpp::XPtr ptr(ext_model_ptr); std::vector param_names; std::vector > param_dims; diff --git a/man/fit-method-constrain_pars.Rd b/man/fit-method-constrain_pars.Rd index 0276d7d4a..f25d206a9 100644 --- a/man/fit-method-constrain_pars.Rd +++ b/man/fit-method-constrain_pars.Rd @@ -6,7 +6,7 @@ \title{Transform a set of unconstrained parameter values to the constrained scale} \usage{ constrain_pars( - upars = NULL, + upars, transformed_parameters = TRUE, generated_quantities = TRUE, skeleton_only = FALSE diff --git a/tests/testthat/test-model-methods.R b/tests/testthat/test-model-methods.R index 940f983b3..d1b6f7fe8 100644 --- a/tests/testthat/test-model-methods.R +++ b/tests/testthat/test-model-methods.R @@ -116,7 +116,6 @@ test_that("methods error for incorrect inputs", { }) test_that("Methods error with already-compiled model", { - mod1 <- testing_model("bernoulli") mod <- testing_model("bernoulli") data_list <- testing_data("bernoulli") fit <- mod$sample(data = data_list, chains = 1)