Skip to content

Commit

Permalink
Rework to use model seed (by default) and chain offset
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed Apr 3, 2023
1 parent 6e696f9 commit 018e69b
Show file tree
Hide file tree
Showing 16 changed files with 197 additions and 138 deletions.
45 changes: 28 additions & 17 deletions R/R/bridgestan.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@ StanModel <- R6::R6Class("StanModel",
#' Create a Stan Model instance.
#' @param lib A path to a compiled BridgeStan Shared Object file.
#' @param data Either a JSON string literal or a path to a data file in JSON format ending in ".json".
#' @param rng_seed Seed for the RNG in the model object.
#' @param seed Seed for the RNG in the model object.
#' @return A new StanModel.
initialize = function(lib, data, rng_seed) {
initialize = function(lib, data, seed) {
if (.Platform$OS.type == "windows"){
lib_old <- lib
lib <- paste0(tools::file_path_sans_ext(lib), ".dll")
file.copy(from=lib_old, to=lib)
}

private$seed <- seed
private$lib <- tools::file_path_as_absolute(lib)
private$lib_name <- tools::file_path_sans_ext(basename(lib))
if (is.loaded("construct_R", PACKAGE = private$lib_name)) {
Expand All @@ -32,7 +33,7 @@ StanModel <- R6::R6Class("StanModel",

dyn.load(private$lib, PACKAGE = private$lib_name)
ret <- .C("bs_construct_R",
as.character(data), as.integer(rng_seed),
as.character(data), as.integer(seed),
ptr_out = raw(8),
err_msg = as.character(""),
err_ptr = raw(8),
Expand Down Expand Up @@ -124,17 +125,18 @@ StanModel <- R6::R6Class("StanModel",
#' @param theta_unc The vector of unconstrained parameters.
#' @param include_tp Whether to also output the transformed parameters of the model.
#' @param include_gq Whether to also output the generated quantities of the model.
#' @param seed The seed for the random number generator. One of
#' `rng` or `seed` must be specified if `include_gq` is `True`.
#' @param chain_id A chain ID used to offset a PRNG seeded with the model's seed
#' which should be unique between calls. One of `rng` or `seed` must be specified
#' if `include_gq` is `True`.
#' @param rng The random number generator to use. See `StanModel$new_rng()`.
#' One of `rng` or `seed` must be specified if `include_gq` is `True`.
#' @return The constrained parameters of the model.
param_constrain = function(theta_unc, include_tp = FALSE, include_gq = FALSE, seed, rng) {
if (missing(seed) && missing(rng)){
param_constrain = function(theta_unc, include_tp = FALSE, include_gq = FALSE, chain_id, rng) {
if (missing(chain_id) && missing(rng)){
if (include_gq) {
stop("Either seed or rng must be specified if include_gq is True.")
stop("Either chain_id or rng must be specified if include_gq is True.")
} else {
seed = 0
chain_id <- 0
}
}
if (!missing(rng)){
Expand All @@ -148,10 +150,10 @@ StanModel <- R6::R6Class("StanModel",
PACKAGE = private$lib_name
)
} else {
vars <- .C("bs_param_constrain_seed_R", as.raw(private$model),
vars <- .C("bs_param_constrain_id_R", as.raw(private$model),
as.logical(include_tp), as.logical(include_gq), as.double(theta_unc),
theta = double(self$param_num(include_tp = include_tp, include_gq = include_gq)),
seed = as.integer(seed),
chain_id = as.integer(chain_id),
return_code = as.integer(0),
err_msg = as.character(""),
err_ptr = raw(8),
Expand All @@ -166,10 +168,16 @@ StanModel <- R6::R6Class("StanModel",
},
#' @description
#' Create a new persistent PRNG object for use in `param_constrain()`.
#' @param seed The seed for the PRNG.
#' @param chain_id Identifier for a sequence in the RNG. This should be made
#' a distinct number for each PRNG created with the same seed
#' (for example, 1:N for N PRNGS).
#' @param seed The seed for the PRNG. If this is not specified, the model's seed is used.
#' @return A `StanRNG` object.
new_rng = function(seed) {
StanRNG$new(private$lib_name, seed)
new_rng = function(chain_id, seed) {
if (missing(seed)){
seed <- private$seed
}
StanRNG$new(private$lib_name, seed, chain_id)
},
#' @description
#' Returns a vector of unconstrained parameters give the constrained parameters.
Expand Down Expand Up @@ -284,6 +292,7 @@ StanModel <- R6::R6Class("StanModel",
lib = NA,
lib_name = NA,
model = NA,
seed = NA,
finalize = function() {
.C("bs_destruct_R",
as.raw(private$model),
Expand Down Expand Up @@ -315,13 +324,15 @@ StanRNG <- R6::R6Class("StanRNG",
#' @description
#' Create a StanRng
#' @param lib_name The name of the Stan dynamic library.
#' @param rng_seed The seed for the RNG.
#' @param seed The seed for the RNG.
#' @param chain_id The chain ID for the RNG.
#' @return A new StanRNG.
initialize = function(lib_name, rng_seed) {
initialize = function(lib_name, seed, chain_id) {
private$lib_name <- lib_name

vars <- .C("bs_construct_rng_R",
as.integer(rng_seed),
as.integer(seed),
as.integer(chain_id),
ptr_out = raw(8),
err_msg = as.character(""),
err_ptr = raw(8),
Expand Down
19 changes: 12 additions & 7 deletions R/man/StanModel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions R/man/StanRNG.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions R/tests/testthat/test-bridgestan.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,11 @@ test_that("param_constrain handles rng arguments", {
expect_equal(4, length(full$param_constrain(c(1.2), include_tp=TRUE, include_gq=TRUE, rng=rng)))

# check reproducibility
expect_equal(full$param_constrain(c(1.2), include_gq=TRUE, seed=1234),
full$param_constrain(c(1.2), include_gq=TRUE, seed=1234))
expect_equal(full$param_constrain(c(1.2), include_gq=TRUE, chain_id=3),
full$param_constrain(c(1.2), include_gq=TRUE, chain_id=3))

# require at least one present
expect_error(full$param_constrain(c(1.2), include_gq=TRUE), "seed or rng must be specified")
expect_error(full$param_constrain(c(1.2), include_gq=TRUE), "chain_id or rng must be specified")
})


Expand All @@ -142,5 +142,5 @@ test_that("param_constrain propagates errors", {

m2 <- load_model("throw_gq",include_data=FALSE)
m2$param_constrain(c(1.2)) # no error
expect_error(m2$param_constrain(c(1.2), include_gq=TRUE, seed=123), "find this text: gqfails")
expect_error(m2$param_constrain(c(1.2), include_gq=TRUE, chain_id=1), "find this text: gqfails")
})
Loading

0 comments on commit 018e69b

Please sign in to comment.