diff --git a/R/R/bridgestan.R b/R/R/bridgestan.R index d68d4e2e..8ba70df2 100755 --- a/R/R/bridgestan.R +++ b/R/R/bridgestan.R @@ -11,15 +11,16 @@ StanModel <- R6::R6Class("StanModel", #' Create a Stan Model instance. #' @param lib A path to a compiled BridgeStan Shared Object file. #' @param data Either a JSON string literal or a path to a data file in JSON format ending in ".json". - #' @param rng_seed Seed for the RNG in the model object. + #' @param seed Seed for the RNG in the model object. #' @return A new StanModel. - initialize = function(lib, data, rng_seed) { + initialize = function(lib, data, seed) { if (.Platform$OS.type == "windows"){ lib_old <- lib lib <- paste0(tools::file_path_sans_ext(lib), ".dll") file.copy(from=lib_old, to=lib) } + private$seed <- seed private$lib <- tools::file_path_as_absolute(lib) private$lib_name <- tools::file_path_sans_ext(basename(lib)) if (is.loaded("construct_R", PACKAGE = private$lib_name)) { @@ -32,7 +33,7 @@ StanModel <- R6::R6Class("StanModel", dyn.load(private$lib, PACKAGE = private$lib_name) ret <- .C("bs_construct_R", - as.character(data), as.integer(rng_seed), + as.character(data), as.integer(seed), ptr_out = raw(8), err_msg = as.character(""), err_ptr = raw(8), @@ -124,17 +125,18 @@ StanModel <- R6::R6Class("StanModel", #' @param theta_unc The vector of unconstrained parameters. #' @param include_tp Whether to also output the transformed parameters of the model. #' @param include_gq Whether to also output the generated quantities of the model. - #' @param seed The seed for the random number generator. One of - #' `rng` or `seed` must be specified if `include_gq` is `True`. + #' @param chain_id A chain ID used to offset a PRNG seeded with the model's seed + #' which should be unique between calls. One of `rng` or `seed` must be specified + #' if `include_gq` is `True`. #' @param rng The random number generator to use. See `StanModel$new_rng()`. #' One of `rng` or `seed` must be specified if `include_gq` is `True`. #' @return The constrained parameters of the model. - param_constrain = function(theta_unc, include_tp = FALSE, include_gq = FALSE, seed, rng) { - if (missing(seed) && missing(rng)){ + param_constrain = function(theta_unc, include_tp = FALSE, include_gq = FALSE, chain_id, rng) { + if (missing(chain_id) && missing(rng)){ if (include_gq) { - stop("Either seed or rng must be specified if include_gq is True.") + stop("Either chain_id or rng must be specified if include_gq is True.") } else { - seed = 0 + chain_id <- 0 } } if (!missing(rng)){ @@ -148,10 +150,10 @@ StanModel <- R6::R6Class("StanModel", PACKAGE = private$lib_name ) } else { - vars <- .C("bs_param_constrain_seed_R", as.raw(private$model), + vars <- .C("bs_param_constrain_id_R", as.raw(private$model), as.logical(include_tp), as.logical(include_gq), as.double(theta_unc), theta = double(self$param_num(include_tp = include_tp, include_gq = include_gq)), - seed = as.integer(seed), + chain_id = as.integer(chain_id), return_code = as.integer(0), err_msg = as.character(""), err_ptr = raw(8), @@ -166,10 +168,16 @@ StanModel <- R6::R6Class("StanModel", }, #' @description #' Create a new persistent PRNG object for use in `param_constrain()`. - #' @param seed The seed for the PRNG. + #' @param chain_id Identifier for a sequence in the RNG. This should be made + #' a distinct number for each PRNG created with the same seed + #' (for example, 1:N for N PRNGS). + #' @param seed The seed for the PRNG. If this is not specified, the model's seed is used. #' @return A `StanRNG` object. - new_rng = function(seed) { - StanRNG$new(private$lib_name, seed) + new_rng = function(chain_id, seed) { + if (missing(seed)){ + seed <- private$seed + } + StanRNG$new(private$lib_name, seed, chain_id) }, #' @description #' Returns a vector of unconstrained parameters give the constrained parameters. @@ -284,6 +292,7 @@ StanModel <- R6::R6Class("StanModel", lib = NA, lib_name = NA, model = NA, + seed = NA, finalize = function() { .C("bs_destruct_R", as.raw(private$model), @@ -315,13 +324,15 @@ StanRNG <- R6::R6Class("StanRNG", #' @description #' Create a StanRng #' @param lib_name The name of the Stan dynamic library. - #' @param rng_seed The seed for the RNG. + #' @param seed The seed for the RNG. + #' @param chain_id The chain ID for the RNG. #' @return A new StanRNG. - initialize = function(lib_name, rng_seed) { + initialize = function(lib_name, seed, chain_id) { private$lib_name <- lib_name vars <- .C("bs_construct_rng_R", - as.integer(rng_seed), + as.integer(seed), + as.integer(chain_id), ptr_out = raw(8), err_msg = as.character(""), err_ptr = raw(8), diff --git a/R/man/StanModel.Rd b/R/man/StanModel.Rd index 466916cf..89894be9 100644 --- a/R/man/StanModel.Rd +++ b/R/man/StanModel.Rd @@ -39,7 +39,7 @@ as well as constraining and unconstraining transforms. \subsection{Method \code{new()}}{ Create a Stan Model instance. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{StanModel$new(lib, data, rng_seed)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{StanModel$new(lib, data, seed)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -49,7 +49,7 @@ Create a Stan Model instance. \item{\code{data}}{Either a JSON string literal or a path to a data file in JSON format ending in ".json".} -\item{\code{rng_seed}}{Seed for the RNG in the model object.} +\item{\code{seed}}{Seed for the RNG in the model object.} } \if{html}{\out{}} } @@ -177,7 +177,7 @@ See also \code{StanModel$param_unconstrain()}, the inverse of this function. theta_unc, include_tp = FALSE, include_gq = FALSE, - seed, + chain_id, rng )}\if{html}{\out{}} } @@ -191,8 +191,9 @@ See also \code{StanModel$param_unconstrain()}, the inverse of this function. \item{\code{include_gq}}{Whether to also output the generated quantities of the model.} -\item{\code{seed}}{The seed for the random number generator. One of -\code{rng} or \code{seed} must be specified if \code{include_gq} is \code{True}.} +\item{\code{chain_id}}{A chain ID used to offset a PRNG seeded with the model's seed +which should be unique between calls. One of \code{rng} or \code{seed} must be specified +if \code{include_gq} is \code{True}.} \item{\code{rng}}{The random number generator to use. See \code{StanModel$new_rng()}. One of \code{rng} or \code{seed} must be specified if \code{include_gq} is \code{True}.} @@ -209,13 +210,17 @@ The constrained parameters of the model. \subsection{Method \code{new_rng()}}{ Create a new persistent PRNG object for use in \code{param_constrain()}. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{StanModel$new_rng(seed)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{StanModel$new_rng(chain_id, seed)}\if{html}{\out{
}} } \subsection{Arguments}{ \if{html}{\out{
}} \describe{ -\item{\code{seed}}{The seed for the PRNG.} +\item{\code{chain_id}}{Identifier for a sequence in the RNG. This should be made +a distinct number for each PRNG created with the same seed +(for example, 1:N for N PRNGS).} + +\item{\code{seed}}{The seed for the PRNG. If this is not specified, the model's seed is used.} } \if{html}{\out{
}} } diff --git a/R/man/StanRNG.Rd b/R/man/StanRNG.Rd index d0219d38..f0af6940 100644 --- a/R/man/StanRNG.Rd +++ b/R/man/StanRNG.Rd @@ -31,7 +31,7 @@ RNG object for use with \code{StanModel$param_constrain()} \subsection{Method \code{new()}}{ Create a StanRng \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{StanRNG$new(lib_name, rng_seed)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{StanRNG$new(lib_name, seed, chain_id)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -39,7 +39,9 @@ Create a StanRng \describe{ \item{\code{lib_name}}{The name of the Stan dynamic library.} -\item{\code{rng_seed}}{The seed for the RNG.} +\item{\code{seed}}{The seed for the RNG.} + +\item{\code{chain_id}}{The chain ID for the RNG.} } \if{html}{\out{}} } diff --git a/R/tests/testthat/test-bridgestan.R b/R/tests/testthat/test-bridgestan.R index e1b9da08..97cfa683 100644 --- a/R/tests/testthat/test-bridgestan.R +++ b/R/tests/testthat/test-bridgestan.R @@ -115,11 +115,11 @@ test_that("param_constrain handles rng arguments", { expect_equal(4, length(full$param_constrain(c(1.2), include_tp=TRUE, include_gq=TRUE, rng=rng))) # check reproducibility - expect_equal(full$param_constrain(c(1.2), include_gq=TRUE, seed=1234), - full$param_constrain(c(1.2), include_gq=TRUE, seed=1234)) + expect_equal(full$param_constrain(c(1.2), include_gq=TRUE, chain_id=3), + full$param_constrain(c(1.2), include_gq=TRUE, chain_id=3)) # require at least one present - expect_error(full$param_constrain(c(1.2), include_gq=TRUE), "seed or rng must be specified") + expect_error(full$param_constrain(c(1.2), include_gq=TRUE), "chain_id or rng must be specified") }) @@ -142,5 +142,5 @@ test_that("param_constrain propagates errors", { m2 <- load_model("throw_gq",include_data=FALSE) m2$param_constrain(c(1.2)) # no error - expect_error(m2$param_constrain(c(1.2), include_gq=TRUE, seed=123), "find this text: gqfails") + expect_error(m2$param_constrain(c(1.2), include_gq=TRUE, chain_id=1), "find this text: gqfails") }) diff --git a/docs/languages/julia.md b/docs/languages/julia.md index 71af5d95..e31b2b29 100644 --- a/docs/languages/julia.md +++ b/docs/languages/julia.md @@ -152,7 +152,7 @@ Return the log density of the specified unconstrained parameters. This calculation drops constant terms that do not depend on the parameters if `propto` is `true` and includes change of variables terms for constrained parameters if `jacobian` is `true`. -source
+source
# **`BridgeStan.log_density_gradient`** — *Function*. @@ -170,7 +170,7 @@ This calculation drops constant terms that do not depend on the parameters if `p This allocates new memory for the gradient output each call. See `log_density_gradient!` for a version which allows re-using existing memory. -source
+source
# **`BridgeStan.log_density_hessian`** — *Function*. @@ -188,7 +188,7 @@ This calculation drops constant terms that do not depend on the parameters if `p This allocates new memory for the gradient and Hessian output each call. See `log_density_gradient!` for a version which allows re-using existing memory. -source
+source
# **`BridgeStan.param_constrain`** — *Function*. @@ -196,19 +196,19 @@ This allocates new memory for the gradient and Hessian output each call. See `lo ```julia -param_constrain(sm, theta_unc, out; include_tp=false, include_gq=false, seed=nothing, rng=nothing) +param_constrain(sm, theta_unc, out; include_tp=false, include_gq=false, chain_id=nothing, rng=nothing) ``` Returns a vector constrained parameters given unconstrained parameters. Additionally (if `include_tp` and `include_gq` are set, respectively) returns transformed parameters and generated quantities. -If `include_gq` is set, then either `seed` or `rng` must be provided. See `StanRNG` for details on how to construct persistent RNGs. +If `include_gq` is set, then either `chain_id` or `rng` must be provided. `chain_id` specifies an offset in a PRNG seeded with the model's seed which should be unique between calls. See `StanRNG` for details on how to construct persistent RNGs. This allocates new memory for the output each call. See `param_constrain!` for a version which allows re-using existing memory. This is the inverse of `param_unconstrain`. -source
+source
# **`BridgeStan.param_unconstrain`** — *Function*. @@ -228,7 +228,7 @@ This allocates new memory for the output each call. See `param_unconstrain!` for This is the inverse of `param_constrain`. -source
+source
# **`BridgeStan.param_unconstrain_json`** — *Function*. @@ -246,7 +246,7 @@ The JSON is expected to be in the [JSON Format for CmdStan](https://mc-stan.org/ This allocates new memory for the output each call. See `param_unconstrain_json!` for a version which allows re-using existing memory. -source
+source
# **`BridgeStan.name`** — *Function*. @@ -260,7 +260,7 @@ name(sm) Return the name of the model `sm` -source
+source
# **`BridgeStan.model_info`** — *Function*. @@ -276,7 +276,7 @@ Return information about the model `sm`. This includes the Stan version and important compiler flags. -source
+source
# **`BridgeStan.param_num`** — *Function*. @@ -292,7 +292,7 @@ Return the number of (constrained) parameters in the model. This is the total of all the sizes of items declared in the `parameters` block of the model. If `include_tp` or `include_gq` are true, items declared in the `transformed parameters` and `generate quantities` blocks are included, respectively. -source
+source
# **`BridgeStan.param_unc_num`** — *Function*. @@ -308,7 +308,7 @@ Return the number of unconstrained parameters in the model. This function is mainly different from `param_num` when variables are declared with constraints. For example, `simplex[5]` has a constrained size of 5, but an unconstrained size of 4. -source
+source
# **`BridgeStan.param_names`** — *Function*. @@ -326,7 +326,7 @@ For containers, indexes are separated by periods (.). For example, the scalar `a` has indexed name `"a"`, the vector entry `a[1]` has indexed name `"a.1"` and the matrix entry `a[2, 3]` has indexed names `"a.2.3"`. Parameter order of the output is column major and more generally last-index major for containers. -source
+source
# **`BridgeStan.param_unc_names`** — *Function*. @@ -342,7 +342,7 @@ Return the indexed names of the unconstrained parameters. For example, a scalar unconstrained parameter `b` has indexed name `b` and a vector entry `b[3]` has indexed name `b.3`. -source
+source
# **`BridgeStan.log_density_gradient!`** — *Function*. @@ -360,7 +360,7 @@ This calculation drops constant terms that do not depend on the parameters if `p The gradient is stored in the vector `out`, and a reference is returned. See `log_density_gradient` for a version which allocates fresh memory. -source
+source
# **`BridgeStan.log_density_hessian!`** — *Function*. @@ -378,7 +378,7 @@ This calculation drops constant terms that do not depend on the parameters if `p The gradient is stored in the vector `out_grad` and the Hessian is stored in `out_hess` and references are returned. See `log_density_hessian` for a version which allocates fresh memory. -source
+source
# **`BridgeStan.param_constrain!`** — *Function*. @@ -386,19 +386,19 @@ The gradient is stored in the vector `out_grad` and the Hessian is stored in `ou ```julia -param_constrain!(sm, theta_unc, out; include_tp=false, include_gq=false, seed=nothing, rng=nothing) +param_constrain!(sm, theta_unc, out; include_tp=false, include_gq=false, chain_id=nothing, rng=nothing) ``` Returns a vector constrained parameters given unconstrained parameters. Additionally (if `include_tp` and `include_gq` are set, respectively) returns transformed parameters and generated quantities. -If `include_gq` is set, then either `seed` or `rng` must be provided. See `StanRNG` for details on how to construct persistent RNGs. +If `include_gq` is set, then either `chain_id` or `rng` must be provided. `chain_id` specifies an offset in a PRNG seeded with the model's seed which should be unique between calls. See `StanRNG` for details on how to construct persistent RNGs. The result is stored in the vector `out`, and a reference is returned. See `param_constrain` for a version which allocates fresh memory. This is the inverse of `param_unconstrain!`. -source
+source
# **`BridgeStan.param_unconstrain!`** — *Function*. @@ -418,7 +418,7 @@ The result is stored in the vector `out`, and a reference is returned. See `para This is the inverse of `param_constrain!`. -source
+source
# **`BridgeStan.param_unconstrain_json!`** — *Function*. @@ -436,7 +436,7 @@ The JSON is expected to be in the [JSON Format for CmdStan](https://mc-stan.org/ The result is stored in the vector `out`, and a reference is returned. See `param_unconstrain_json` for a version which allocates fresh memory. -source
+source
# **`BridgeStan.StanRNG`** — *Type*. @@ -444,15 +444,17 @@ The result is stored in the vector `out`, and a reference is returned. See `para ```julia -StanRNG(sm::StanModel, seed) +StanRNG(sm::StanModel, chain_id; seed=nothing) ``` -Construct a StanRNG instance from a `StanModel` instance and a seed. This can be used in the `param_constrain` and `param_constrain!` methods when using the generated quantities block. +Construct a StanRNG instance from a `StanModel` instance and a chain ID. This ID serves as an offset in the PRNG stream and should be unique for each PRNG created with the same seed. A seed can be supplied to use in place of the model's seed, which is used by default. + +This can be used in the `param_constrain` and `param_constrain!` methods when using the generated quantities block. This object is not thread-safe, one should be created per thread. -source
+source
diff --git a/docs/languages/r.md b/docs/languages/r.md index 3d39b854..20c5199e 100644 --- a/docs/languages/r.md +++ b/docs/languages/r.md @@ -217,7 +217,7 @@ of this function. _Usage_ ```R -StanModel$param_constrain(theta_unc, include_tp = FALSE, include_gq = FALSE, seed, rng) +StanModel$param_constrain(theta_unc, include_tp = FALSE, include_gq = FALSE, chain_id, rng) ``` @@ -231,11 +231,11 @@ _Arguments_ - `include_gq` Whether to also output the generated quantities of the model. - - `seed` Seed for the RNG in the model object. One of - `rng` or `seed` must be specified if `include_gq` is `True` + - `chain_id` A chain ID used to offset a PRNG seeded with the model's seed + which should be unique between calls. One of `rng` or `chain_id` must be specified if `include_gq` is `True` - `rng` StanRNG to use in the model object. See `StanModel$new_rng()`. - One of `rng` or `seed` must be specified if `include_gq` is `True` + One of `rng` or `chain_id` must be specified if `include_gq` is `True` _Returns_ @@ -251,13 +251,16 @@ Create a new persistent PRNG object for use in `param_constrain()`. _Usage_ ```R -StanModel$new_rng(seed) +StanModel$new_rng(chain_id, seed) ``` _Arguments_ - - `seed` The seed for the PRNG. + - `chain_id` Identifier for a sequence in the RNG. This should be made a distinct number for each PRNG created with the sane seed (for example, 1:N for N PRNGS). + + - `seed` The seed for the PRNG. If this is not specified, the model's seed is used. + _Returns_ diff --git a/julia/src/model.jl b/julia/src/model.jl index 4abf914e..88d40117 100644 --- a/julia/src/model.jl +++ b/julia/src/model.jl @@ -90,9 +90,13 @@ mutable struct StanModel end """ - StanRNG(sm::StanModel, seed) + StanRNG(sm::StanModel, chain_id; seed=nothing) + +Construct a StanRNG instance from a `StanModel` instance and a chain ID. +This ID serves as an offset in the PRNG stream and should be unique for each +PRNG created with the same seed. +A seed can be supplied to use in place of the model's seed, which is used by default. -Construct a StanRNG instance from a `StanModel` instance and a seed. This can be used in the `param_constrain` and `param_constrain!` methods when using the generated quantities block. @@ -102,24 +106,26 @@ mutable struct StanRNG lib::Ptr{Nothing} rng::Ptr{StanRNGStruct} seed::UInt32 + chain_id::UInt32 - function StanRNG(sm::StanModel, seed) - seed = convert(UInt32, seed) - + function StanRNG(sm::StanModel, chain_id; seed=nothing) + seed = convert(UInt32, if isnothing(seed) sm.seed else seed end) + chain_id = convert(UInt32, chain_id) err = Ref{Cstring}() rng = ccall( Libc.Libdl.dlsym(sm.lib, "bs_construct_rng"), Ptr{StanModelStruct}, - (UInt32, Ref{Cstring}), + (UInt32, UInt32, Ref{Cstring}), seed, + chain_id, err, ) if rng == C_NULL error(_handle_error(sm.lib, err, "bs_construct_rng")) end - stanrng = new(sm.lib, rng, seed) + stanrng = new(sm.lib, rng, seed, chain_id) function f(stanrng) ccall( @@ -250,13 +256,15 @@ function param_unc_names(sm::StanModel) end """ - param_constrain!(sm, theta_unc, out; include_tp=false, include_gq=false, seed=nothing, rng=nothing) + param_constrain!(sm, theta_unc, out; include_tp=false, include_gq=false, chain_id=nothing, rng=nothing) Returns a vector constrained parameters given unconstrained parameters. Additionally (if `include_tp` and `include_gq` are set, respectively) returns transformed parameters and generated quantities. -If `include_gq` is set, then either `seed` or `rng` must be provided. +If `include_gq` is set, then either `chain_id` or `rng` must be provided. +`chain_id` specifies an offset in a PRNG seeded with the model's +seed which should be unique between calls. See `StanRNG` for details on how to construct persistent RNGs. The result is stored in the vector `out`, and a reference is returned. See @@ -270,7 +278,7 @@ function param_constrain!( out::Vector{Float64}; include_tp = false, include_gq = false, - seed::Union{Int, Nothing} = nothing, + chain_id::Union{Int, Nothing} = nothing, rng::Union{StanRNG, Nothing} = nothing, ) dims = param_num(sm; include_tp = include_tp, include_gq = include_gq) @@ -280,15 +288,15 @@ function param_constrain!( ) end - if seed === nothing && rng === nothing + if chain_id === nothing && rng === nothing if include_gq throw( ArgumentError( - "Must provide either a seed or an RNG when including generated quantities", + "Must provide either a chain_id or an RNG when including generated quantities", ), ) else - seed = 0 + chain_id = 0 end end @@ -309,9 +317,9 @@ function param_constrain!( err, ) else - seed = convert(UInt32, seed) + chain_id = convert(UInt32, chain_id) rc = ccall( - Libc.Libdl.dlsym(sm.lib, "bs_param_constrain_seed"), + Libc.Libdl.dlsym(sm.lib, "bs_param_constrain_id"), Cint, (Ptr{StanModelStruct}, Cint, Cint, Ref{Cdouble}, Ref{Cdouble}, Cuint, Ref{Cstring}), sm.stanmodel, @@ -319,7 +327,7 @@ function param_constrain!( include_gq, theta_unc, out, - seed, + chain_id, err, ) end @@ -330,13 +338,15 @@ function param_constrain!( end """ - param_constrain(sm, theta_unc, out; include_tp=false, include_gq=false, seed=nothing, rng=nothing) + param_constrain(sm, theta_unc, out; include_tp=false, include_gq=false, chain_id=nothing, rng=nothing) Returns a vector constrained parameters given unconstrained parameters. Additionally (if `include_tp` and `include_gq` are set, respectively) returns transformed parameters and generated quantities. -If `include_gq` is set, then either `seed` or `rng` must be provided. +If `include_gq` is set, then either `chain_id` or `rng` must be provided. +`chain_id` specifies an offset in a PRNG seeded with the model's +seed which should be unique between calls. See `StanRNG` for details on how to construct persistent RNGs. This allocates new memory for the output each call. @@ -350,11 +360,11 @@ function param_constrain( theta_unc::Vector{Float64}; include_tp = false, include_gq = false, - seed::Union{Int, Nothing} = nothing, + chain_id::Union{Int, Nothing} = nothing, rng::Union{StanRNG, Nothing} = nothing, ) out = zeros(param_num(sm, include_tp = include_tp, include_gq = include_gq)) - param_constrain!(sm, theta_unc, out; include_tp = include_tp, include_gq = include_gq, seed=seed, rng=rng) + param_constrain!(sm, theta_unc, out; include_tp = include_tp, include_gq = include_gq, chain_id=chain_id, rng=rng) end """ diff --git a/julia/test/model_tests.jl b/julia/test/model_tests.jl index feed1660..48248ac0 100644 --- a/julia/test/model_tests.jl +++ b/julia/test/model_tests.jl @@ -176,7 +176,7 @@ end model2 = load_test_model("full", false) - rng = StanRNG(model2, 1234) + rng = StanRNG(model2, 2; seed=1234) @test 1 == length(BridgeStan.param_constrain(model2, a)) @test 2 == length(BridgeStan.param_constrain(model2, a; include_tp = true)) @test 3 == length(BridgeStan.param_constrain(model2, a; include_gq = true, rng=rng)) @@ -185,8 +185,8 @@ end ) # reproducibility - @test isapprox(BridgeStan.param_constrain(model2, a; include_gq = true, seed=1234), - BridgeStan.param_constrain(model2, a; include_gq = true, seed=1234)) + @test isapprox(BridgeStan.param_constrain(model2, a; include_gq = true, chain_id=3), + BridgeStan.param_constrain(model2, a; include_gq = true, chain_id=3)) # no seed or rng provided @test_throws ArgumentError BridgeStan.param_constrain(model2, a; include_gq = true) @@ -207,7 +207,7 @@ end model4, y; include_gq = true, - seed=1234 + chain_id=1, ) end diff --git a/python/bridgestan/model.py b/python/bridgestan/model.py index a7457927..bd77c725 100644 --- a/python/bridgestan/model.py +++ b/python/bridgestan/model.py @@ -45,8 +45,7 @@ def __init__( :param model_lib: A system path to compiled shared object. :param model_data: Either a JSON string literal or a system path to a data file in JSON format ending in ``.json``. - :param seed: A pseudo random number generator seed. This is only used - if the model has RNG usage in the ``transformed data`` block. + :param seed: A pseudo random number generator seed. :raises FileNotFoundError or PermissionError: If ``model_lib`` is not readable or ``model_data`` is specified and not a path to a readable file. :raises RuntimeError: If there is an error instantiating the @@ -118,9 +117,9 @@ def __init__( star_star_char, ] - self._param_constrain_seed = self.stanlib.bs_param_constrain_seed - self._param_constrain_seed.restype = ctypes.c_int - self._param_constrain_seed.argtypes = [ + self._param_constrain_id = self.stanlib.bs_param_constrain_id + self._param_constrain_id.restype = ctypes.c_int + self._param_constrain_id.argtypes = [ ctypes.c_void_p, ctypes.c_int, ctypes.c_int, @@ -230,7 +229,7 @@ def __del__(self) -> None: def __repr__(self) -> str: data = f"{self.data_path!r}, " if self.data_path else "" - return f"StanModel({self.lib_path!r}, {data}seed={self.seed})" + return f"StanModel({self.lib_path!r}, {data}, seed={self.seed})" def name(self) -> str: """ @@ -309,7 +308,7 @@ def param_constrain( include_tp: bool = False, include_gq: bool = False, out: Optional[FloatArray] = None, - seed: Optional[int] = None, + chain_id: Optional[int] = None, rng: Optional["StanRNG"] = None, ) -> FloatArray: """ @@ -326,28 +325,28 @@ def param_constrain( provided, it must have shape `(D, )`, where `D` is the number of constrained parameters. If not provided or `None`, a freshly allocated array is returned. - :param seed: A pseudo random number generator seed. One of - ``rng`` or ``seed`` must be specified if ``include_gq`` - is ``True``. + :param chain_id: A chain ID used to offset a PRNG seeded with the model's + seed which should be unique between calls. One of ``rng`` or + ``chain_id`` must be specified if ``include_gq`` is ``True``. :param rng: A ``StanRNG`` object to use for generating random numbers, see :meth:`~StanModel.new_rng`. - One of ``rng`` or ``seed`` must be specified if ``include_gq`` + One of ``rng`` or ``chain_id`` must be specified if ``include_gq`` is ``True``. :return: The constrained parameter array. :raises ValueError: If ``out`` is specified and is not the same shape as the return. - :raises ValueError: If neither ``rng`` nor ``seed`` is specified + :raises ValueError: If neither ``rng`` nor ``chain_id`` is specified and ``include_gq`` is ``True``. :raises RuntimeError: If the C++ Stan model throws an exception. """ - if seed is None and rng is None: + if chain_id is None and rng is None: if include_gq: raise ValueError( - "Error: must specify rng or seed when including generated quantities" + "Error: must specify rng or chain_id when including generated quantities" ) else: - # neither specified, but not doing gq, so use a fixed seed - seed = 0 + # neither specified, but not doing gq, so use a fixed chain id + chain_id = 0 dims = self.param_num(include_tp=include_tp, include_gq=include_gq) if out is None: @@ -369,13 +368,13 @@ def param_constrain( err, ) else: - rc = self._param_constrain_seed( + rc = self._param_constrain_id( self.model_rng, int(include_tp), int(include_gq), theta_unc, out, - seed, + chain_id, err, ) @@ -383,14 +382,18 @@ def param_constrain( raise self._handle_error(err.contents, "param_constrain") return out - def new_rng(self, seed: int) -> "StanRNG": + def new_rng(self, chain_id: int, *, seed=None) -> "StanRNG": """ Return a new PRNG for use in :meth:`~StanModel.param_constrain`. - :param seed: The seed for the PRNG. + :param chain_id: Identifier for a sequence in the RNG. This should be made + a distinct number for each PRNG created with the same seed + (for example, 1:N for N PRNGS). + :param seed: A seed for the PRNG. If not specified, the + model's seed is used. :return: A new PRNG for the model. """ - return StanRNG(self.stanlib, seed) + return StanRNG(self.stanlib, seed or self.seed, chain_id) def param_unconstrain( self, theta: FloatArray, *, out: Optional[FloatArray] = None @@ -606,13 +609,13 @@ def _handle_error(self, err: ctypes.c_char_p, method: str) -> Exception: class StanRNG: - def __init__(self, lib: ctypes.CDLL, seed: int) -> None: + def __init__(self, lib: ctypes.CDLL, seed: int, chain_id: int) -> None: self.stanlib = lib construct = self.stanlib.bs_construct_rng construct.restype = ctypes.c_void_p - construct.argtypes = [ctypes.c_uint, star_star_char] - self.ptr = construct(seed, None) + construct.argtypes = [ctypes.c_uint, ctypes.c_uint, star_star_char] + self.ptr = construct(seed, chain_id, None) if not self.ptr: raise RuntimeError("Failed to construct RNG.") diff --git a/python/test/test_stanmodel.py b/python/test/test_stanmodel.py index 32914d77..4d319ac1 100644 --- a/python/test/test_stanmodel.py +++ b/python/test/test_stanmodel.py @@ -208,7 +208,7 @@ def test_param_constrain(): full_so = str(STAN_FOLDER / "full" / "full_model.so") bridge2 = bs.StanModel(full_so) - rng = bridge2.new_rng(seed=1234) + rng = bridge2.new_rng(chain_id=1, seed=1234) np.testing.assert_equal(1, bridge2.param_constrain(a).size) np.testing.assert_equal(2, bridge2.param_constrain(a, include_tp=True).size) @@ -221,8 +221,8 @@ def test_param_constrain(): # reproducibility test np.testing.assert_equal( - bridge2.param_constrain(a, include_gq=True, seed=1234), - bridge2.param_constrain(a, include_gq=True, seed=1234), + bridge2.param_constrain(a, include_gq=True, chain_id=4), + bridge2.param_constrain(a, include_gq=True, chain_id=4), ) # test error if neither seed or rng is provided @@ -251,7 +251,7 @@ def test_param_constrain(): bridge3 = bs.StanModel(throw_gq_so) bridge3.param_constrain(y, include_gq=False) with pytest.raises(RuntimeError, match="find this text: gqfails"): - bridge3.param_constrain(y, include_gq=True, seed=123) + bridge3.param_constrain(y, include_gq=True, chain_id=1) def test_param_unconstrain(): diff --git a/src/bridgestan.cpp b/src/bridgestan.cpp index 18cc76ff..2d590e2c 100644 --- a/src/bridgestan.cpp +++ b/src/bridgestan.cpp @@ -85,11 +85,11 @@ int bs_param_constrain(const bs_model* mr, bool include_tp, bool include_gq, return 1; } -int bs_param_constrain_seed(const bs_model* mr, bool include_tp, +int bs_param_constrain_id(const bs_model* mr, bool include_tp, bool include_gq, const double* theta_unc, - double* theta, unsigned int seed, + double* theta, unsigned int chain_id, char** error_msg) { - bs_rng rng(seed); + bs_rng rng(mr->seed(), chain_id); return bs_param_constrain(mr, include_tp, include_gq, theta_unc, theta, &rng, error_msg); } @@ -208,9 +208,9 @@ int bs_log_density_hessian(const bs_model* mr, bool propto, bool jacobian, return -1; } -bs_rng* bs_construct_rng(unsigned int seed, char** error_msg) { +bs_rng* bs_construct_rng(unsigned int seed, unsigned int chain_id, char** error_msg) { try { - return new bs_rng(seed); + return new bs_rng(seed, chain_id); } catch (const std::exception& e) { if (error_msg) { std::stringstream error; diff --git a/src/bridgestan.h b/src/bridgestan.h index 39cae4e5..16cee57f 100644 --- a/src/bridgestan.h +++ b/src/bridgestan.h @@ -23,7 +23,8 @@ typedef int bool; * @return pointer to constructed model or `nullptr` if construction * fails */ -bs_model* bs_construct(const char* data_file, unsigned int seed, char** error_msg); +bs_model* bs_construct(const char* data_file, unsigned int seed, + char** error_msg); /** * Destroy the model and return 0 for success and -1 if there is an @@ -163,25 +164,26 @@ int bs_param_constrain(const bs_model* mr, bool include_tp, bool include_gq, * in the Stan program, with multivariate parameters given in * last-index-major order. * - * This version accepts a seed which is used to create a fresh PRNG - * which lives only for the duration of this call. + * This version accepts a chain_id which is used to create a PRNG + * offset from the model's seed which lives only for the duration + * of this call. * * @param[in] mr pointer to model and RNG structure * @param[in] include_tp `true` to include transformed parameters * @param[in] include_gq `true` to include generated quantities * @param[in] theta_unc sequence of unconstrained parameters * @param[out] theta sequence of constrained parameters - * @param[in] seed seed for pseudorandom number generator which will be created - * and destroyed during this call. See `bs_param_constrain` for an option with a - * persistent RNG. + * @param[in] chain_id offset for pseudorandom number generator which will be + * created and destroyed during this call. seeded with model seed. See + * `bs_param_constrain` for an option with a persistent RNG. * @param[out] error_msg a pointer to a string that will be allocated if there * is an error. This must later be freed by calling `bs_free_error_msg`. * @return code 0 if successful and code -1 if there is an exception * in the underlying Stan code */ -int bs_param_constrain_seed(const bs_model* mr, bool include_tp, - bool include_gq, const double* theta_unc, - double* theta, unsigned int seed, char** error_msg); +int bs_param_constrain_id(const bs_model* mr, bool include_tp, bool include_gq, + const double* theta_unc, double* theta, + unsigned int chain_id, char** error_msg); /** * Set the sequence of unconstrained parameters based on the @@ -294,15 +296,17 @@ int bs_log_density_hessian(const bs_model* mr, bool propto, bool jacobian, double* hessian, char** error_msg); /** - * Construct an RNG object to be used in `bs_param_constrain`. + * Construct an PRNG object to be used in `bs_param_constrain`. * This object is not thread safe and should be constructed and * destructed for each thread. * * @param[in] seed seed for the RNG + * @param[in] chain_id identifier for the current sequence of PRNG draws * @param[out] error_msg a pointer to a string that will be allocated if there * is an error. This must later be freed by calling `bs_free_error_msg`. */ -bs_rng* bs_construct_rng(unsigned int seed, char** error_msg); +bs_rng* bs_construct_rng(unsigned int seed, unsigned int chain_id, + char** error_msg); /** * Destruct an RNG object. diff --git a/src/bridgestanR.cpp b/src/bridgestanR.cpp index e1e0f246..cc5cdc88 100644 --- a/src/bridgestanR.cpp +++ b/src/bridgestanR.cpp @@ -39,12 +39,12 @@ void bs_param_constrain_R(bs_model** model, int* include_tp, int* include_gq, theta, *rng, err_msg); *err_ptr = static_cast(*err_msg); } -void bs_param_constrain_seed_R(bs_model** model, int* include_tp, +void bs_param_constrain_id_R(bs_model** model, int* include_tp, int* include_gq, const double* theta_unc, - double* theta, int* seed, int* return_code, + double* theta, int* chain_id, int* return_code, char** err_msg, void** err_ptr) { - *return_code = bs_param_constrain_seed(*model, *include_tp, *include_gq, - theta_unc, theta, *seed, err_msg); + *return_code = bs_param_constrain_id(*model, *include_tp, *include_gq, + theta_unc, theta, *chain_id, err_msg); *err_ptr = static_cast(*err_msg); } void bs_param_unconstrain_R(bs_model** model, const double* theta, @@ -82,9 +82,9 @@ void bs_log_density_hessian_R(bs_model** model, int* propto, int* jacobian, val, grad, hess, err_msg); *err_ptr = static_cast(*err_msg); } -void bs_construct_rng_R(int* seed, bs_rng** ptr_out, char** err_msg, +void bs_construct_rng_R(int* seed, int* chain_id, bs_rng** ptr_out, char** err_msg, void** err_ptr) { - *ptr_out = bs_construct_rng(*seed, err_msg); + *ptr_out = bs_construct_rng(*seed, *chain_id, err_msg); *err_ptr = static_cast(*err_msg); } void bs_destruct_rng_R(bs_rng** rng, int* return_code) { diff --git a/src/bridgestanR.h b/src/bridgestanR.h index 811580b5..f14f4481 100644 --- a/src/bridgestanR.h +++ b/src/bridgestanR.h @@ -41,9 +41,9 @@ void bs_param_constrain_R(bs_model** model, int* include_tp, int* include_gq, const double* theta_unc, double* theta, bs_rng** rng, int* return_code, char** err_msg, void** err_ptr); -void bs_param_constrain_seed_R(bs_model** model, int* include_tp, +void bs_param_constrain_id_R(bs_model** model, int* include_tp, int* include_gq, const double* theta_unc, - double* theta, int* seed, int* return_code, + double* theta, int* chain_id, int* return_code, char** err_msg, void** err_ptr); void bs_param_unconstrain_R(bs_model** model, const double* theta, @@ -68,7 +68,7 @@ void bs_log_density_hessian_R(bs_model** model, int* propto, int* jacobian, double* grad, double* hess, int* return_code, char** err_msg, void** err_ptr); -void bs_construct_rng_R(int* seed, bs_rng** ptr_out, char** err_msg, +void bs_construct_rng_R(int* seed, int* chain_id, bs_rng** ptr_out, char** err_msg, void** err_ptr); void bs_destruct_rng_R(bs_rng** rng, int* return_code); diff --git a/src/model_rng.cpp b/src/model_rng.cpp index 3955cadb..887ae6c5 100644 --- a/src/model_rng.cpp +++ b/src/model_rng.cpp @@ -78,6 +78,8 @@ bs_model::bs_model(const char* data_file, unsigned int seed) { } } + seed_ = seed; + std::string model_name = model_->model_name(); const char* model_name_c = model_name.c_str(); name_ = strdup(model_name_c); @@ -167,6 +169,8 @@ bs_model::~bs_model() { const char* bs_model::name() const { return name_; } +unsigned int bs_model::seed() const { return seed_; } + const char* bs_model::model_info() const { return model_info_; } const char* bs_model::param_names(bool include_tp, bool include_gq) const { @@ -323,3 +327,7 @@ void bs_model::log_density_hessian(bool propto, bool jacobian, Eigen::VectorXd::Map(grad, N) = grad_vec; Eigen::MatrixXd::Map(hessian, N, N) = hess_mat; } + +bs_rng::bs_rng(unsigned int seed, unsigned int chain_id) { + rng_ = stan::services::util::create_rng(seed, chain_id + 1); +} diff --git a/src/model_rng.hpp b/src/model_rng.hpp index d83a4f2d..031e2dfd 100644 --- a/src/model_rng.hpp +++ b/src/model_rng.hpp @@ -38,6 +38,14 @@ class bs_model { */ const char* name() const; + /** + * Return the pseudorandom number generator seed + * used during model construction. + * + * @return seed + */ + unsigned int seed() const; + /** * Return information about the compiled model. This class manages the * memory, so the returned string should not be freed. @@ -182,6 +190,9 @@ class bs_model { /** Stan model */ stan::model::model_base* model_; + /** RNG seed provided during model creation */ + unsigned int seed_; + /** name of the Stan model */ char* name_ = nullptr; @@ -230,7 +241,7 @@ class bs_model { */ class bs_rng { public: - bs_rng(unsigned int seed) : rng_(seed) { rng_.discard(1); } + bs_rng(unsigned int seed, unsigned int chain_id); boost::ecuyer1988 rng_; };