Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Breaking change: Implement param_constrain in a thread-safe way #98

Merged
merged 21 commits into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 67 additions & 5 deletions R/R/bridgestan.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ StanModel <- R6::R6Class("StanModel",
#' Create a Stan Model instance.
#' @param lib A path to a compiled BridgeStan Shared Object file.
#' @param data Either a JSON string literal or a path to a data file in JSON format ending in ".json".
WardBrian marked this conversation as resolved.
Show resolved Hide resolved
#' @param rng_seed Seed for the RNG in the model object.
#' @param chain_id Used to offset the RNG by a fixed amount.
#' @param seed Seed for the RNG used in constructing the model.
#' @return A new StanModel.
initialize = function(lib, data, rng_seed, chain_id) {
initialize = function(lib, data, seed) {
if (.Platform$OS.type == "windows"){
lib_old <- lib
lib <- paste0(tools::file_path_sans_ext(lib), ".dll")
file.copy(from=lib_old, to=lib)
}

private$seed <- seed
private$lib <- tools::file_path_as_absolute(lib)
private$lib_name <- tools::file_path_sans_ext(basename(lib))
if (is.loaded("construct_R", PACKAGE = private$lib_name)) {
Expand All @@ -33,7 +33,7 @@ StanModel <- R6::R6Class("StanModel",

dyn.load(private$lib, PACKAGE = private$lib_name)
ret <- .C("bs_construct_R",
as.character(data), as.integer(rng_seed), as.integer(chain_id),
as.character(data), as.integer(seed),
ptr_out = raw(8),
err_msg = as.character(""),
err_ptr = raw(8),
Expand Down Expand Up @@ -140,22 +140,40 @@ StanModel <- R6::R6Class("StanModel",
#' @param theta_unc The vector of unconstrained parameters.
#' @param include_tp Whether to also output the transformed parameters of the model.
#' @param include_gq Whether to also output the generated quantities of the model.
#' @param rng The random number generator to use if `include_gq` is `TRUE`. See `StanModel$new_rng()`.
#' @return The constrained parameters of the model.
param_constrain = function(theta_unc, include_tp = FALSE, include_gq = FALSE) {
param_constrain = function(theta_unc, include_tp = FALSE, include_gq = FALSE, rng) {
if (missing(rng)) {
if (include_gq){
stop("A rng must be provided if include_gq is True.")
}
rng_ptr <- as.integer(0)
} else {
rng_ptr <- as.raw(rng$rng)
}
vars <- .C("bs_param_constrain_R", as.raw(private$model),
as.logical(include_tp), as.logical(include_gq), as.double(theta_unc),
theta = double(self$param_num(include_tp = include_tp, include_gq = include_gq)),
rng = rng_ptr,
return_code = as.integer(0),
err_msg = as.character(""),
err_ptr = raw(8),
PACKAGE = private$lib_name
)

if (vars$return_code) {
stop(handle_error(private$lib_name, vars$err_msg, vars$err_ptr, "param_constrain"))
}
vars$theta
},
#' @description
#' Create a new persistent PRNG object for use in `param_constrain()`.
#' @param seed The seed for the PRNG.
#' @return A `StanRNG` object.
new_rng = function(seed) {
StanRNG$new(private$lib_name, seed)
},
#' @description
#' Returns a vector of unconstrained parameters give the constrained parameters.
#'
#' It is assumed that these will be in the same order as internally represented by
Expand Down Expand Up @@ -268,6 +286,7 @@ StanModel <- R6::R6Class("StanModel",
lib = NA,
lib_name = NA,
model = NA,
seed = NA,
finalize = function() {
.C("bs_destruct_R",
as.raw(private$model),
Expand All @@ -287,3 +306,46 @@ handle_error <- function(lib_name, err_msg, err_ptr, function_name) {
return(err_msg)
}
}

#' StanRNG
#'
#' RNG object for use with `StanModel$param_constrain()`
#' @field rng The pointer to the RNG object.
#' @keywords internal
StanRNG <- R6::R6Class("StanRNG",
public = list(
#' @description
#' Create a StanRng
#' @param lib_name The name of the Stan dynamic library.
#' @param seed The seed for the RNG.
#' @return A new StanRNG.
initialize = function(lib_name, seed) {
private$lib_name <- lib_name

vars <- .C("bs_construct_rng_R",
as.integer(seed),
ptr_out = raw(8),
err_msg = as.character(""),
err_ptr = raw(8),
PACKAGE = private$lib_name
)

if (all(vars$ptr_out == 0)) {
stop(handle_error("construct_rng", vars$err_msg, vars$err_ptr, private$lib_name))
} else {
self$rng <- vars$ptr_out
}
},
rng = NA
WardBrian marked this conversation as resolved.
Show resolved Hide resolved
),
private = list(
lib_name = NA,
finalize = function() {
.C("bs_destruct_rng_R",
as.raw(self$rng),
PACKAGE = private$lib_name
)
}
),
cloneable=FALSE
)
46 changes: 41 additions & 5 deletions R/man/StanModel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

50 changes: 50 additions & 0 deletions R/man/StanRNG.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion R/man/handle_error.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 20 additions & 3 deletions R/tests/testthat/test-bridgestan.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ load_model <- function(name, include_data=TRUE) {
} else {
data = ""
}
model <- StanModel$new(file.path(base,paste0("/test_models/", name, "/",name,"_model.so")), data, 1234, 0)
model <- StanModel$new(file.path(base,paste0("/test_models/", name, "/",name,"_model.so")), data, 1234)
return(model)
}

test_that("missing data throws error", {
expect_error(StanModel$new(file.path(base,paste0("/test_models/simple/simple_model.so")), "", 1234, 0))
expect_error(load_model("simple",include_data=FALSE))
})

simple <- load_model("simple")
Expand Down Expand Up @@ -106,6 +106,23 @@ test_that("param_unconstrain works for a nontrivial case", {
expect_equal(c, a)
})

test_that("param_constrain handles rng arguments", {
full <- load_model("full", include_data=FALSE)
expect_equal(1, length(full$param_constrain(c(1.2))))
expect_equal(2, length(full$param_constrain(c(1.2), include_tp=TRUE)))
rng <- full$new_rng(123)
expect_equal(3, length(full$param_constrain(c(1.2), include_gq=TRUE, rng=rng)))
expect_equal(4, length(full$param_constrain(c(1.2), include_tp=TRUE, include_gq=TRUE, rng=rng)))

# check reproducibility
expect_equal(full$param_constrain(c(1.2), include_gq=TRUE, rng=full$new_rng(456)),
full$param_constrain(c(1.2), include_gq=TRUE, rng=full$new_rng(456)))

# require at least one present
expect_error(full$param_constrain(c(1.2), include_gq=TRUE), "rng must be provided")
})


test_that("constructor propagates errors", {
expect_error(load_model("throw_data",include_data=FALSE), "find this text: datafails")
})
Expand All @@ -125,5 +142,5 @@ test_that("param_constrain propagates errors", {

m2 <- load_model("throw_gq",include_data=FALSE)
m2$param_constrain(c(1.2)) # no error
expect_error(m2$param_constrain(c(1.2), include_gq=TRUE), "find this text: gqfails")
expect_error(m2$param_constrain(c(1.2), include_gq=TRUE, rng=m2$new_rng(1234)), "find this text: gqfails")
})
2 changes: 1 addition & 1 deletion c-example/example.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ int main(int argc, char** argv) {

// this could potentially error, and we may get information back about why.
char* err;
bs_model_rng* model = bs_construct(data, 123, 0, &err);
bs_model* model = bs_construct(data, 123, &err);
if (!model) {
if (err) {
printf("Error: %s", err);
Expand Down
Loading