Skip to content

Commit

Permalink
Breaking change: Implement param_constrain in a thread-safe way (#98)
Browse files Browse the repository at this point in the history
* Start C++ changes for moving RNGs outside model object

* Checkpoint

* Checkpointing; rebased on error msg PR to avoid stepping on myself

* Update doc comments

* Update Python tests

* Start julia

* Update Julia tests, doc

* R interface work

* Documentation

* Update c-example

* Rework to use model seed (by default) and chain offset

* Expose seed AND chain in C API, keep higher-level APIs as prior commit

* Formatting

* Update to reflect that destructor cannot fail

* Typo fix

* Simplify interface; remove concept of chain ID

* Doc changes and naming consistency

* Consistently name construct/destruct functions

* Test thread-safety of param_constrain in Julia

* rework julia test threaded model: full

---------

Co-authored-by: Edward A. Roualdes <eroualdes@csuchico.edu>
  • Loading branch information
WardBrian and roualdes authored Apr 14, 2023
1 parent 9b2fc59 commit 38f9ac0
Show file tree
Hide file tree
Showing 27 changed files with 818 additions and 335 deletions.
78 changes: 70 additions & 8 deletions R/R/bridgestan.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@ StanModel <- R6::R6Class("StanModel",
#' @description
#' Create a Stan Model instance.
#' @param lib A path to a compiled BridgeStan Shared Object file.
#' @param data Either a JSON string literal or a path to a data file in JSON format ending in ".json".
#' @param rng_seed Seed for the RNG in the model object.
#' @param chain_id Used to offset the RNG by a fixed amount.
#' @param data Either a JSON string literal,a path to a data file in JSON format ending in ".json", or the empty string.
#' @param seed Seed for the RNG used in constructing the model.
#' @return A new StanModel.
initialize = function(lib, data, rng_seed, chain_id) {
initialize = function(lib, data, seed) {
if (.Platform$OS.type == "windows"){
lib_old <- lib
lib <- paste0(tools::file_path_sans_ext(lib), ".dll")
file.copy(from=lib_old, to=lib)
}

private$seed <- seed
private$lib <- tools::file_path_as_absolute(lib)
private$lib_name <- tools::file_path_sans_ext(basename(lib))
if (is.loaded("construct_R", PACKAGE = private$lib_name)) {
Expand All @@ -32,8 +32,8 @@ StanModel <- R6::R6Class("StanModel",
}

dyn.load(private$lib, PACKAGE = private$lib_name)
ret <- .C("bs_construct_R",
as.character(data), as.integer(rng_seed), as.integer(chain_id),
ret <- .C("bs_model_construct_R",
as.character(data), as.integer(seed),
ptr_out = raw(8),
err_msg = as.character(""),
err_ptr = raw(8),
Expand Down Expand Up @@ -140,22 +140,40 @@ StanModel <- R6::R6Class("StanModel",
#' @param theta_unc The vector of unconstrained parameters.
#' @param include_tp Whether to also output the transformed parameters of the model.
#' @param include_gq Whether to also output the generated quantities of the model.
#' @param rng The random number generator to use if `include_gq` is `TRUE`. See `StanModel$new_rng()`.
#' @return The constrained parameters of the model.
param_constrain = function(theta_unc, include_tp = FALSE, include_gq = FALSE) {
param_constrain = function(theta_unc, include_tp = FALSE, include_gq = FALSE, rng) {
if (missing(rng)) {
if (include_gq){
stop("A rng must be provided if include_gq is True.")
}
rng_ptr <- as.integer(0)
} else {
rng_ptr <- as.raw(rng$ptr)
}
vars <- .C("bs_param_constrain_R", as.raw(private$model),
as.logical(include_tp), as.logical(include_gq), as.double(theta_unc),
theta = double(self$param_num(include_tp = include_tp, include_gq = include_gq)),
rng = rng_ptr,
return_code = as.integer(0),
err_msg = as.character(""),
err_ptr = raw(8),
PACKAGE = private$lib_name
)

if (vars$return_code) {
stop(handle_error(private$lib_name, vars$err_msg, vars$err_ptr, "param_constrain"))
}
vars$theta
},
#' @description
#' Create a new persistent PRNG object for use in `param_constrain()`.
#' @param seed The seed for the PRNG.
#' @return A `StanRNG` object.
new_rng = function(seed) {
StanRNG$new(private$lib_name, seed)
},
#' @description
#' Returns a vector of unconstrained parameters give the constrained parameters.
#'
#' It is assumed that these will be in the same order as internally represented by
Expand Down Expand Up @@ -268,8 +286,9 @@ StanModel <- R6::R6Class("StanModel",
lib = NA,
lib_name = NA,
model = NA,
seed = NA,
finalize = function() {
.C("bs_destruct_R",
.C("bs_model_destruct_R",
as.raw(private$model),
PACKAGE = private$lib_name
)
Expand All @@ -287,3 +306,46 @@ handle_error <- function(lib_name, err_msg, err_ptr, function_name) {
return(err_msg)
}
}

#' StanRNG
#'
#' RNG object for use with `StanModel$param_constrain()`
#' @field rng The pointer to the RNG object.
#' @keywords internal
StanRNG <- R6::R6Class("StanRNG",
public = list(
#' @description
#' Create a StanRng
#' @param lib_name The name of the Stan dynamic library.
#' @param seed The seed for the RNG.
#' @return A new StanRNG.
initialize = function(lib_name, seed) {
private$lib_name <- lib_name

vars <- .C("bs_rng_construct_R",
as.integer(seed),
ptr_out = raw(8),
err_msg = as.character(""),
err_ptr = raw(8),
PACKAGE = private$lib_name
)

if (all(vars$ptr_out == 0)) {
stop(handle_error("construct_rng", vars$err_msg, vars$err_ptr, private$lib_name))
} else {
self$ptr <- vars$ptr_out
}
},
ptr = NA
),
private = list(
lib_name = NA,
finalize = function() {
.C("bs_rng_destruct_R",
as.raw(self$ptr),
PACKAGE = private$lib_name
)
}
),
cloneable=FALSE
)
46 changes: 41 additions & 5 deletions R/man/StanModel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

50 changes: 50 additions & 0 deletions R/man/StanRNG.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion R/man/handle_error.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 20 additions & 3 deletions R/tests/testthat/test-bridgestan.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ load_model <- function(name, include_data=TRUE) {
} else {
data = ""
}
model <- StanModel$new(file.path(base,paste0("/test_models/", name, "/",name,"_model.so")), data, 1234, 0)
model <- StanModel$new(file.path(base,paste0("/test_models/", name, "/",name,"_model.so")), data, 1234)
return(model)
}

test_that("missing data throws error", {
expect_error(StanModel$new(file.path(base,paste0("/test_models/simple/simple_model.so")), "", 1234, 0))
expect_error(load_model("simple",include_data=FALSE))
})

simple <- load_model("simple")
Expand Down Expand Up @@ -106,6 +106,23 @@ test_that("param_unconstrain works for a nontrivial case", {
expect_equal(c, a)
})

test_that("param_constrain handles rng arguments", {
full <- load_model("full", include_data=FALSE)
expect_equal(1, length(full$param_constrain(c(1.2))))
expect_equal(2, length(full$param_constrain(c(1.2), include_tp=TRUE)))
rng <- full$new_rng(123)
expect_equal(3, length(full$param_constrain(c(1.2), include_gq=TRUE, rng=rng)))
expect_equal(4, length(full$param_constrain(c(1.2), include_tp=TRUE, include_gq=TRUE, rng=rng)))

# check reproducibility
expect_equal(full$param_constrain(c(1.2), include_gq=TRUE, rng=full$new_rng(456)),
full$param_constrain(c(1.2), include_gq=TRUE, rng=full$new_rng(456)))

# require at least one present
expect_error(full$param_constrain(c(1.2), include_gq=TRUE), "rng must be provided")
})


test_that("constructor propagates errors", {
expect_error(load_model("throw_data",include_data=FALSE), "find this text: datafails")
})
Expand All @@ -125,5 +142,5 @@ test_that("param_constrain propagates errors", {

m2 <- load_model("throw_gq",include_data=FALSE)
m2$param_constrain(c(1.2)) # no error
expect_error(m2$param_constrain(c(1.2), include_gq=TRUE), "find this text: gqfails")
expect_error(m2$param_constrain(c(1.2), include_gq=TRUE, rng=m2$new_rng(1234)), "find this text: gqfails")
})
4 changes: 2 additions & 2 deletions c-example/example.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ int main(int argc, char** argv) {

// this could potentially error, and we may get information back about why.
char* err;
bs_model_rng* model = bs_construct(data, 123, 0, &err);
bs_model* model = bs_model_construct(data, 123, &err);
if (!model) {
if (err) {
printf("Error: %s", err);
Expand All @@ -26,6 +26,6 @@ int main(int argc, char** argv) {
printf("This model's name is %s.\n", bs_name(model));
printf("It has %d parameters.\n", bs_param_num(model, 0, 0));

bs_destruct(model);
bs_model_destruct(model);
return 0;
}
Loading

0 comments on commit 38f9ac0

Please sign in to comment.