diff --git a/R/R/bridgestan.R b/R/R/bridgestan.R
index d68d4e2e..8ba70df2 100755
--- a/R/R/bridgestan.R
+++ b/R/R/bridgestan.R
@@ -11,15 +11,16 @@ StanModel <- R6::R6Class("StanModel",
#' Create a Stan Model instance.
#' @param lib A path to a compiled BridgeStan Shared Object file.
#' @param data Either a JSON string literal or a path to a data file in JSON format ending in ".json".
- #' @param rng_seed Seed for the RNG in the model object.
+ #' @param seed Seed for the RNG in the model object.
#' @return A new StanModel.
- initialize = function(lib, data, rng_seed) {
+ initialize = function(lib, data, seed) {
if (.Platform$OS.type == "windows"){
lib_old <- lib
lib <- paste0(tools::file_path_sans_ext(lib), ".dll")
file.copy(from=lib_old, to=lib)
}
+ private$seed <- seed
private$lib <- tools::file_path_as_absolute(lib)
private$lib_name <- tools::file_path_sans_ext(basename(lib))
if (is.loaded("construct_R", PACKAGE = private$lib_name)) {
@@ -32,7 +33,7 @@ StanModel <- R6::R6Class("StanModel",
dyn.load(private$lib, PACKAGE = private$lib_name)
ret <- .C("bs_construct_R",
- as.character(data), as.integer(rng_seed),
+ as.character(data), as.integer(seed),
ptr_out = raw(8),
err_msg = as.character(""),
err_ptr = raw(8),
@@ -124,17 +125,18 @@ StanModel <- R6::R6Class("StanModel",
#' @param theta_unc The vector of unconstrained parameters.
#' @param include_tp Whether to also output the transformed parameters of the model.
#' @param include_gq Whether to also output the generated quantities of the model.
- #' @param seed The seed for the random number generator. One of
- #' `rng` or `seed` must be specified if `include_gq` is `True`.
+ #' @param chain_id A chain ID used to offset a PRNG seeded with the model's seed
+ #' which should be unique between calls. One of `rng` or `seed` must be specified
+ #' if `include_gq` is `True`.
#' @param rng The random number generator to use. See `StanModel$new_rng()`.
#' One of `rng` or `seed` must be specified if `include_gq` is `True`.
#' @return The constrained parameters of the model.
- param_constrain = function(theta_unc, include_tp = FALSE, include_gq = FALSE, seed, rng) {
- if (missing(seed) && missing(rng)){
+ param_constrain = function(theta_unc, include_tp = FALSE, include_gq = FALSE, chain_id, rng) {
+ if (missing(chain_id) && missing(rng)){
if (include_gq) {
- stop("Either seed or rng must be specified if include_gq is True.")
+ stop("Either chain_id or rng must be specified if include_gq is True.")
} else {
- seed = 0
+ chain_id <- 0
}
}
if (!missing(rng)){
@@ -148,10 +150,10 @@ StanModel <- R6::R6Class("StanModel",
PACKAGE = private$lib_name
)
} else {
- vars <- .C("bs_param_constrain_seed_R", as.raw(private$model),
+ vars <- .C("bs_param_constrain_id_R", as.raw(private$model),
as.logical(include_tp), as.logical(include_gq), as.double(theta_unc),
theta = double(self$param_num(include_tp = include_tp, include_gq = include_gq)),
- seed = as.integer(seed),
+ chain_id = as.integer(chain_id),
return_code = as.integer(0),
err_msg = as.character(""),
err_ptr = raw(8),
@@ -166,10 +168,16 @@ StanModel <- R6::R6Class("StanModel",
},
#' @description
#' Create a new persistent PRNG object for use in `param_constrain()`.
- #' @param seed The seed for the PRNG.
+ #' @param chain_id Identifier for a sequence in the RNG. This should be made
+ #' a distinct number for each PRNG created with the same seed
+ #' (for example, 1:N for N PRNGS).
+ #' @param seed The seed for the PRNG. If this is not specified, the model's seed is used.
#' @return A `StanRNG` object.
- new_rng = function(seed) {
- StanRNG$new(private$lib_name, seed)
+ new_rng = function(chain_id, seed) {
+ if (missing(seed)){
+ seed <- private$seed
+ }
+ StanRNG$new(private$lib_name, seed, chain_id)
},
#' @description
#' Returns a vector of unconstrained parameters give the constrained parameters.
@@ -284,6 +292,7 @@ StanModel <- R6::R6Class("StanModel",
lib = NA,
lib_name = NA,
model = NA,
+ seed = NA,
finalize = function() {
.C("bs_destruct_R",
as.raw(private$model),
@@ -315,13 +324,15 @@ StanRNG <- R6::R6Class("StanRNG",
#' @description
#' Create a StanRng
#' @param lib_name The name of the Stan dynamic library.
- #' @param rng_seed The seed for the RNG.
+ #' @param seed The seed for the RNG.
+ #' @param chain_id The chain ID for the RNG.
#' @return A new StanRNG.
- initialize = function(lib_name, rng_seed) {
+ initialize = function(lib_name, seed, chain_id) {
private$lib_name <- lib_name
vars <- .C("bs_construct_rng_R",
- as.integer(rng_seed),
+ as.integer(seed),
+ as.integer(chain_id),
ptr_out = raw(8),
err_msg = as.character(""),
err_ptr = raw(8),
diff --git a/R/man/StanModel.Rd b/R/man/StanModel.Rd
index 466916cf..89894be9 100644
--- a/R/man/StanModel.Rd
+++ b/R/man/StanModel.Rd
@@ -39,7 +39,7 @@ as well as constraining and unconstraining transforms.
\subsection{Method \code{new()}}{
Create a Stan Model instance.
\subsection{Usage}{
-\if{html}{\out{
}}\preformatted{StanModel$new(lib, data, rng_seed)}\if{html}{\out{
}}
+\if{html}{\out{}}\preformatted{StanModel$new(lib, data, seed)}\if{html}{\out{
}}
}
\subsection{Arguments}{
@@ -49,7 +49,7 @@ Create a Stan Model instance.
\item{\code{data}}{Either a JSON string literal or a path to a data file in JSON format ending in ".json".}
-\item{\code{rng_seed}}{Seed for the RNG in the model object.}
+\item{\code{seed}}{Seed for the RNG in the model object.}
}
\if{html}{\out{}}
}
@@ -177,7 +177,7 @@ See also \code{StanModel$param_unconstrain()}, the inverse of this function.
theta_unc,
include_tp = FALSE,
include_gq = FALSE,
- seed,
+ chain_id,
rng
)}\if{html}{\out{}}
}
@@ -191,8 +191,9 @@ See also \code{StanModel$param_unconstrain()}, the inverse of this function.
\item{\code{include_gq}}{Whether to also output the generated quantities of the model.}
-\item{\code{seed}}{The seed for the random number generator. One of
-\code{rng} or \code{seed} must be specified if \code{include_gq} is \code{True}.}
+\item{\code{chain_id}}{A chain ID used to offset a PRNG seeded with the model's seed
+which should be unique between calls. One of \code{rng} or \code{seed} must be specified
+if \code{include_gq} is \code{True}.}
\item{\code{rng}}{The random number generator to use. See \code{StanModel$new_rng()}.
One of \code{rng} or \code{seed} must be specified if \code{include_gq} is \code{True}.}
@@ -209,13 +210,17 @@ The constrained parameters of the model.
\subsection{Method \code{new_rng()}}{
Create a new persistent PRNG object for use in \code{param_constrain()}.
\subsection{Usage}{
-\if{html}{\out{}}\preformatted{StanModel$new_rng(seed)}\if{html}{\out{
}}
+\if{html}{\out{}}\preformatted{StanModel$new_rng(chain_id, seed)}\if{html}{\out{
}}
}
\subsection{Arguments}{
\if{html}{\out{}}
\describe{
-\item{\code{seed}}{The seed for the PRNG.}
+\item{\code{chain_id}}{Identifier for a sequence in the RNG. This should be made
+a distinct number for each PRNG created with the same seed
+(for example, 1:N for N PRNGS).}
+
+\item{\code{seed}}{The seed for the PRNG. If this is not specified, the model's seed is used.}
}
\if{html}{\out{
}}
}
diff --git a/R/man/StanRNG.Rd b/R/man/StanRNG.Rd
index d0219d38..f0af6940 100644
--- a/R/man/StanRNG.Rd
+++ b/R/man/StanRNG.Rd
@@ -31,7 +31,7 @@ RNG object for use with \code{StanModel$param_constrain()}
\subsection{Method \code{new()}}{
Create a StanRng
\subsection{Usage}{
-\if{html}{\out{}}\preformatted{StanRNG$new(lib_name, rng_seed)}\if{html}{\out{
}}
+\if{html}{\out{}}\preformatted{StanRNG$new(lib_name, seed, chain_id)}\if{html}{\out{
}}
}
\subsection{Arguments}{
@@ -39,7 +39,9 @@ Create a StanRng
\describe{
\item{\code{lib_name}}{The name of the Stan dynamic library.}
-\item{\code{rng_seed}}{The seed for the RNG.}
+\item{\code{seed}}{The seed for the RNG.}
+
+\item{\code{chain_id}}{The chain ID for the RNG.}
}
\if{html}{\out{}}
}
diff --git a/R/tests/testthat/test-bridgestan.R b/R/tests/testthat/test-bridgestan.R
index e1b9da08..97cfa683 100644
--- a/R/tests/testthat/test-bridgestan.R
+++ b/R/tests/testthat/test-bridgestan.R
@@ -115,11 +115,11 @@ test_that("param_constrain handles rng arguments", {
expect_equal(4, length(full$param_constrain(c(1.2), include_tp=TRUE, include_gq=TRUE, rng=rng)))
# check reproducibility
- expect_equal(full$param_constrain(c(1.2), include_gq=TRUE, seed=1234),
- full$param_constrain(c(1.2), include_gq=TRUE, seed=1234))
+ expect_equal(full$param_constrain(c(1.2), include_gq=TRUE, chain_id=3),
+ full$param_constrain(c(1.2), include_gq=TRUE, chain_id=3))
# require at least one present
- expect_error(full$param_constrain(c(1.2), include_gq=TRUE), "seed or rng must be specified")
+ expect_error(full$param_constrain(c(1.2), include_gq=TRUE), "chain_id or rng must be specified")
})
@@ -142,5 +142,5 @@ test_that("param_constrain propagates errors", {
m2 <- load_model("throw_gq",include_data=FALSE)
m2$param_constrain(c(1.2)) # no error
- expect_error(m2$param_constrain(c(1.2), include_gq=TRUE, seed=123), "find this text: gqfails")
+ expect_error(m2$param_constrain(c(1.2), include_gq=TRUE, chain_id=1), "find this text: gqfails")
})
diff --git a/docs/languages/julia.md b/docs/languages/julia.md
index 71af5d95..e31b2b29 100644
--- a/docs/languages/julia.md
+++ b/docs/languages/julia.md
@@ -152,7 +152,7 @@ Return the log density of the specified unconstrained parameters.
This calculation drops constant terms that do not depend on the parameters if `propto` is `true` and includes change of variables terms for constrained parameters if `jacobian` is `true`.
-source
+source
#
**`BridgeStan.log_density_gradient`** — *Function*.
@@ -170,7 +170,7 @@ This calculation drops constant terms that do not depend on the parameters if `p
This allocates new memory for the gradient output each call. See `log_density_gradient!` for a version which allows re-using existing memory.
-source
+source
#
**`BridgeStan.log_density_hessian`** — *Function*.
@@ -188,7 +188,7 @@ This calculation drops constant terms that do not depend on the parameters if `p
This allocates new memory for the gradient and Hessian output each call. See `log_density_gradient!` for a version which allows re-using existing memory.
-source
+source
#
**`BridgeStan.param_constrain`** — *Function*.
@@ -196,19 +196,19 @@ This allocates new memory for the gradient and Hessian output each call. See `lo
```julia
-param_constrain(sm, theta_unc, out; include_tp=false, include_gq=false, seed=nothing, rng=nothing)
+param_constrain(sm, theta_unc, out; include_tp=false, include_gq=false, chain_id=nothing, rng=nothing)
```
Returns a vector constrained parameters given unconstrained parameters. Additionally (if `include_tp` and `include_gq` are set, respectively) returns transformed parameters and generated quantities.
-If `include_gq` is set, then either `seed` or `rng` must be provided. See `StanRNG` for details on how to construct persistent RNGs.
+If `include_gq` is set, then either `chain_id` or `rng` must be provided. `chain_id` specifies an offset in a PRNG seeded with the model's seed which should be unique between calls. See `StanRNG` for details on how to construct persistent RNGs.
This allocates new memory for the output each call. See `param_constrain!` for a version which allows re-using existing memory.
This is the inverse of `param_unconstrain`.
-source
+source
#
**`BridgeStan.param_unconstrain`** — *Function*.
@@ -228,7 +228,7 @@ This allocates new memory for the output each call. See `param_unconstrain!` for
This is the inverse of `param_constrain`.
-source
+source
#
**`BridgeStan.param_unconstrain_json`** — *Function*.
@@ -246,7 +246,7 @@ The JSON is expected to be in the [JSON Format for CmdStan](https://mc-stan.org/
This allocates new memory for the output each call. See `param_unconstrain_json!` for a version which allows re-using existing memory.
-source
+source
#
**`BridgeStan.name`** — *Function*.
@@ -260,7 +260,7 @@ name(sm)
Return the name of the model `sm`
-source
+source
#
**`BridgeStan.model_info`** — *Function*.
@@ -276,7 +276,7 @@ Return information about the model `sm`.
This includes the Stan version and important compiler flags.
-source
+source
#
**`BridgeStan.param_num`** — *Function*.
@@ -292,7 +292,7 @@ Return the number of (constrained) parameters in the model.
This is the total of all the sizes of items declared in the `parameters` block of the model. If `include_tp` or `include_gq` are true, items declared in the `transformed parameters` and `generate quantities` blocks are included, respectively.
-source
+source
#
**`BridgeStan.param_unc_num`** — *Function*.
@@ -308,7 +308,7 @@ Return the number of unconstrained parameters in the model.
This function is mainly different from `param_num` when variables are declared with constraints. For example, `simplex[5]` has a constrained size of 5, but an unconstrained size of 4.
-source
+source
#
**`BridgeStan.param_names`** — *Function*.
@@ -326,7 +326,7 @@ For containers, indexes are separated by periods (.).
For example, the scalar `a` has indexed name `"a"`, the vector entry `a[1]` has indexed name `"a.1"` and the matrix entry `a[2, 3]` has indexed names `"a.2.3"`. Parameter order of the output is column major and more generally last-index major for containers.
-source
+source
#
**`BridgeStan.param_unc_names`** — *Function*.
@@ -342,7 +342,7 @@ Return the indexed names of the unconstrained parameters.
For example, a scalar unconstrained parameter `b` has indexed name `b` and a vector entry `b[3]` has indexed name `b.3`.
-source
+source
#
**`BridgeStan.log_density_gradient!`** — *Function*.
@@ -360,7 +360,7 @@ This calculation drops constant terms that do not depend on the parameters if `p
The gradient is stored in the vector `out`, and a reference is returned. See `log_density_gradient` for a version which allocates fresh memory.
-source
+source
#
**`BridgeStan.log_density_hessian!`** — *Function*.
@@ -378,7 +378,7 @@ This calculation drops constant terms that do not depend on the parameters if `p
The gradient is stored in the vector `out_grad` and the Hessian is stored in `out_hess` and references are returned. See `log_density_hessian` for a version which allocates fresh memory.
-source
+source
#
**`BridgeStan.param_constrain!`** — *Function*.
@@ -386,19 +386,19 @@ The gradient is stored in the vector `out_grad` and the Hessian is stored in `ou
```julia
-param_constrain!(sm, theta_unc, out; include_tp=false, include_gq=false, seed=nothing, rng=nothing)
+param_constrain!(sm, theta_unc, out; include_tp=false, include_gq=false, chain_id=nothing, rng=nothing)
```
Returns a vector constrained parameters given unconstrained parameters. Additionally (if `include_tp` and `include_gq` are set, respectively) returns transformed parameters and generated quantities.
-If `include_gq` is set, then either `seed` or `rng` must be provided. See `StanRNG` for details on how to construct persistent RNGs.
+If `include_gq` is set, then either `chain_id` or `rng` must be provided. `chain_id` specifies an offset in a PRNG seeded with the model's seed which should be unique between calls. See `StanRNG` for details on how to construct persistent RNGs.
The result is stored in the vector `out`, and a reference is returned. See `param_constrain` for a version which allocates fresh memory.
This is the inverse of `param_unconstrain!`.
-source
+source
#
**`BridgeStan.param_unconstrain!`** — *Function*.
@@ -418,7 +418,7 @@ The result is stored in the vector `out`, and a reference is returned. See `para
This is the inverse of `param_constrain!`.
-source
+source
#
**`BridgeStan.param_unconstrain_json!`** — *Function*.
@@ -436,7 +436,7 @@ The JSON is expected to be in the [JSON Format for CmdStan](https://mc-stan.org/
The result is stored in the vector `out`, and a reference is returned. See `param_unconstrain_json` for a version which allocates fresh memory.
-source
+source
#
**`BridgeStan.StanRNG`** — *Type*.
@@ -444,15 +444,17 @@ The result is stored in the vector `out`, and a reference is returned. See `para
```julia
-StanRNG(sm::StanModel, seed)
+StanRNG(sm::StanModel, chain_id; seed=nothing)
```
-Construct a StanRNG instance from a `StanModel` instance and a seed. This can be used in the `param_constrain` and `param_constrain!` methods when using the generated quantities block.
+Construct a StanRNG instance from a `StanModel` instance and a chain ID. This ID serves as an offset in the PRNG stream and should be unique for each PRNG created with the same seed. A seed can be supplied to use in place of the model's seed, which is used by default.
+
+This can be used in the `param_constrain` and `param_constrain!` methods when using the generated quantities block.
This object is not thread-safe, one should be created per thread.
-source
+source
diff --git a/docs/languages/r.md b/docs/languages/r.md
index 3d39b854..20c5199e 100644
--- a/docs/languages/r.md
+++ b/docs/languages/r.md
@@ -217,7 +217,7 @@ of this function.
_Usage_
```R
-StanModel$param_constrain(theta_unc, include_tp = FALSE, include_gq = FALSE, seed, rng)
+StanModel$param_constrain(theta_unc, include_tp = FALSE, include_gq = FALSE, chain_id, rng)
```
@@ -231,11 +231,11 @@ _Arguments_
- `include_gq` Whether to also output the generated quantities
of the model.
- - `seed` Seed for the RNG in the model object. One of
- `rng` or `seed` must be specified if `include_gq` is `True`
+ - `chain_id` A chain ID used to offset a PRNG seeded with the model's seed
+ which should be unique between calls. One of `rng` or `chain_id` must be specified if `include_gq` is `True`
- `rng` StanRNG to use in the model object. See `StanModel$new_rng()`.
- One of `rng` or `seed` must be specified if `include_gq` is `True`
+ One of `rng` or `chain_id` must be specified if `include_gq` is `True`
_Returns_
@@ -251,13 +251,16 @@ Create a new persistent PRNG object for use in `param_constrain()`.
_Usage_
```R
-StanModel$new_rng(seed)
+StanModel$new_rng(chain_id, seed)
```
_Arguments_
- - `seed` The seed for the PRNG.
+ - `chain_id` Identifier for a sequence in the RNG. This should be made a distinct number for each PRNG created with the sane seed (for example, 1:N for N PRNGS).
+
+ - `seed` The seed for the PRNG. If this is not specified, the model's seed is used.
+
_Returns_
diff --git a/julia/src/model.jl b/julia/src/model.jl
index 4abf914e..88d40117 100644
--- a/julia/src/model.jl
+++ b/julia/src/model.jl
@@ -90,9 +90,13 @@ mutable struct StanModel
end
"""
- StanRNG(sm::StanModel, seed)
+ StanRNG(sm::StanModel, chain_id; seed=nothing)
+
+Construct a StanRNG instance from a `StanModel` instance and a chain ID.
+This ID serves as an offset in the PRNG stream and should be unique for each
+PRNG created with the same seed.
+A seed can be supplied to use in place of the model's seed, which is used by default.
-Construct a StanRNG instance from a `StanModel` instance and a seed.
This can be used in the `param_constrain` and `param_constrain!` methods
when using the generated quantities block.
@@ -102,24 +106,26 @@ mutable struct StanRNG
lib::Ptr{Nothing}
rng::Ptr{StanRNGStruct}
seed::UInt32
+ chain_id::UInt32
- function StanRNG(sm::StanModel, seed)
- seed = convert(UInt32, seed)
-
+ function StanRNG(sm::StanModel, chain_id; seed=nothing)
+ seed = convert(UInt32, if isnothing(seed) sm.seed else seed end)
+ chain_id = convert(UInt32, chain_id)
err = Ref{Cstring}()
rng = ccall(
Libc.Libdl.dlsym(sm.lib, "bs_construct_rng"),
Ptr{StanModelStruct},
- (UInt32, Ref{Cstring}),
+ (UInt32, UInt32, Ref{Cstring}),
seed,
+ chain_id,
err,
)
if rng == C_NULL
error(_handle_error(sm.lib, err, "bs_construct_rng"))
end
- stanrng = new(sm.lib, rng, seed)
+ stanrng = new(sm.lib, rng, seed, chain_id)
function f(stanrng)
ccall(
@@ -250,13 +256,15 @@ function param_unc_names(sm::StanModel)
end
"""
- param_constrain!(sm, theta_unc, out; include_tp=false, include_gq=false, seed=nothing, rng=nothing)
+ param_constrain!(sm, theta_unc, out; include_tp=false, include_gq=false, chain_id=nothing, rng=nothing)
Returns a vector constrained parameters given unconstrained parameters.
Additionally (if `include_tp` and `include_gq` are set, respectively)
returns transformed parameters and generated quantities.
-If `include_gq` is set, then either `seed` or `rng` must be provided.
+If `include_gq` is set, then either `chain_id` or `rng` must be provided.
+`chain_id` specifies an offset in a PRNG seeded with the model's
+seed which should be unique between calls.
See `StanRNG` for details on how to construct persistent RNGs.
The result is stored in the vector `out`, and a reference is returned. See
@@ -270,7 +278,7 @@ function param_constrain!(
out::Vector{Float64};
include_tp = false,
include_gq = false,
- seed::Union{Int, Nothing} = nothing,
+ chain_id::Union{Int, Nothing} = nothing,
rng::Union{StanRNG, Nothing} = nothing,
)
dims = param_num(sm; include_tp = include_tp, include_gq = include_gq)
@@ -280,15 +288,15 @@ function param_constrain!(
)
end
- if seed === nothing && rng === nothing
+ if chain_id === nothing && rng === nothing
if include_gq
throw(
ArgumentError(
- "Must provide either a seed or an RNG when including generated quantities",
+ "Must provide either a chain_id or an RNG when including generated quantities",
),
)
else
- seed = 0
+ chain_id = 0
end
end
@@ -309,9 +317,9 @@ function param_constrain!(
err,
)
else
- seed = convert(UInt32, seed)
+ chain_id = convert(UInt32, chain_id)
rc = ccall(
- Libc.Libdl.dlsym(sm.lib, "bs_param_constrain_seed"),
+ Libc.Libdl.dlsym(sm.lib, "bs_param_constrain_id"),
Cint,
(Ptr{StanModelStruct}, Cint, Cint, Ref{Cdouble}, Ref{Cdouble}, Cuint, Ref{Cstring}),
sm.stanmodel,
@@ -319,7 +327,7 @@ function param_constrain!(
include_gq,
theta_unc,
out,
- seed,
+ chain_id,
err,
)
end
@@ -330,13 +338,15 @@ function param_constrain!(
end
"""
- param_constrain(sm, theta_unc, out; include_tp=false, include_gq=false, seed=nothing, rng=nothing)
+ param_constrain(sm, theta_unc, out; include_tp=false, include_gq=false, chain_id=nothing, rng=nothing)
Returns a vector constrained parameters given unconstrained parameters.
Additionally (if `include_tp` and `include_gq` are set, respectively)
returns transformed parameters and generated quantities.
-If `include_gq` is set, then either `seed` or `rng` must be provided.
+If `include_gq` is set, then either `chain_id` or `rng` must be provided.
+`chain_id` specifies an offset in a PRNG seeded with the model's
+seed which should be unique between calls.
See `StanRNG` for details on how to construct persistent RNGs.
This allocates new memory for the output each call.
@@ -350,11 +360,11 @@ function param_constrain(
theta_unc::Vector{Float64};
include_tp = false,
include_gq = false,
- seed::Union{Int, Nothing} = nothing,
+ chain_id::Union{Int, Nothing} = nothing,
rng::Union{StanRNG, Nothing} = nothing,
)
out = zeros(param_num(sm, include_tp = include_tp, include_gq = include_gq))
- param_constrain!(sm, theta_unc, out; include_tp = include_tp, include_gq = include_gq, seed=seed, rng=rng)
+ param_constrain!(sm, theta_unc, out; include_tp = include_tp, include_gq = include_gq, chain_id=chain_id, rng=rng)
end
"""
diff --git a/julia/test/model_tests.jl b/julia/test/model_tests.jl
index feed1660..48248ac0 100644
--- a/julia/test/model_tests.jl
+++ b/julia/test/model_tests.jl
@@ -176,7 +176,7 @@ end
model2 = load_test_model("full", false)
- rng = StanRNG(model2, 1234)
+ rng = StanRNG(model2, 2; seed=1234)
@test 1 == length(BridgeStan.param_constrain(model2, a))
@test 2 == length(BridgeStan.param_constrain(model2, a; include_tp = true))
@test 3 == length(BridgeStan.param_constrain(model2, a; include_gq = true, rng=rng))
@@ -185,8 +185,8 @@ end
)
# reproducibility
- @test isapprox(BridgeStan.param_constrain(model2, a; include_gq = true, seed=1234),
- BridgeStan.param_constrain(model2, a; include_gq = true, seed=1234))
+ @test isapprox(BridgeStan.param_constrain(model2, a; include_gq = true, chain_id=3),
+ BridgeStan.param_constrain(model2, a; include_gq = true, chain_id=3))
# no seed or rng provided
@test_throws ArgumentError BridgeStan.param_constrain(model2, a; include_gq = true)
@@ -207,7 +207,7 @@ end
model4,
y;
include_gq = true,
- seed=1234
+ chain_id=1,
)
end
diff --git a/python/bridgestan/model.py b/python/bridgestan/model.py
index a7457927..bd77c725 100644
--- a/python/bridgestan/model.py
+++ b/python/bridgestan/model.py
@@ -45,8 +45,7 @@ def __init__(
:param model_lib: A system path to compiled shared object.
:param model_data: Either a JSON string literal or a
system path to a data file in JSON format ending in ``.json``.
- :param seed: A pseudo random number generator seed. This is only used
- if the model has RNG usage in the ``transformed data`` block.
+ :param seed: A pseudo random number generator seed.
:raises FileNotFoundError or PermissionError: If ``model_lib`` is not readable or
``model_data`` is specified and not a path to a readable file.
:raises RuntimeError: If there is an error instantiating the
@@ -118,9 +117,9 @@ def __init__(
star_star_char,
]
- self._param_constrain_seed = self.stanlib.bs_param_constrain_seed
- self._param_constrain_seed.restype = ctypes.c_int
- self._param_constrain_seed.argtypes = [
+ self._param_constrain_id = self.stanlib.bs_param_constrain_id
+ self._param_constrain_id.restype = ctypes.c_int
+ self._param_constrain_id.argtypes = [
ctypes.c_void_p,
ctypes.c_int,
ctypes.c_int,
@@ -230,7 +229,7 @@ def __del__(self) -> None:
def __repr__(self) -> str:
data = f"{self.data_path!r}, " if self.data_path else ""
- return f"StanModel({self.lib_path!r}, {data}seed={self.seed})"
+ return f"StanModel({self.lib_path!r}, {data}, seed={self.seed})"
def name(self) -> str:
"""
@@ -309,7 +308,7 @@ def param_constrain(
include_tp: bool = False,
include_gq: bool = False,
out: Optional[FloatArray] = None,
- seed: Optional[int] = None,
+ chain_id: Optional[int] = None,
rng: Optional["StanRNG"] = None,
) -> FloatArray:
"""
@@ -326,28 +325,28 @@ def param_constrain(
provided, it must have shape `(D, )`, where `D` is the number of
constrained parameters. If not provided or `None`, a freshly
allocated array is returned.
- :param seed: A pseudo random number generator seed. One of
- ``rng`` or ``seed`` must be specified if ``include_gq``
- is ``True``.
+ :param chain_id: A chain ID used to offset a PRNG seeded with the model's
+ seed which should be unique between calls. One of ``rng`` or
+ ``chain_id`` must be specified if ``include_gq`` is ``True``.
:param rng: A ``StanRNG`` object to use for generating random
numbers, see :meth:`~StanModel.new_rng`.
- One of ``rng`` or ``seed`` must be specified if ``include_gq``
+ One of ``rng`` or ``chain_id`` must be specified if ``include_gq``
is ``True``.
:return: The constrained parameter array.
:raises ValueError: If ``out`` is specified and is not the same
shape as the return.
- :raises ValueError: If neither ``rng`` nor ``seed`` is specified
+ :raises ValueError: If neither ``rng`` nor ``chain_id`` is specified
and ``include_gq`` is ``True``.
:raises RuntimeError: If the C++ Stan model throws an exception.
"""
- if seed is None and rng is None:
+ if chain_id is None and rng is None:
if include_gq:
raise ValueError(
- "Error: must specify rng or seed when including generated quantities"
+ "Error: must specify rng or chain_id when including generated quantities"
)
else:
- # neither specified, but not doing gq, so use a fixed seed
- seed = 0
+ # neither specified, but not doing gq, so use a fixed chain id
+ chain_id = 0
dims = self.param_num(include_tp=include_tp, include_gq=include_gq)
if out is None:
@@ -369,13 +368,13 @@ def param_constrain(
err,
)
else:
- rc = self._param_constrain_seed(
+ rc = self._param_constrain_id(
self.model_rng,
int(include_tp),
int(include_gq),
theta_unc,
out,
- seed,
+ chain_id,
err,
)
@@ -383,14 +382,18 @@ def param_constrain(
raise self._handle_error(err.contents, "param_constrain")
return out
- def new_rng(self, seed: int) -> "StanRNG":
+ def new_rng(self, chain_id: int, *, seed=None) -> "StanRNG":
"""
Return a new PRNG for use in :meth:`~StanModel.param_constrain`.
- :param seed: The seed for the PRNG.
+ :param chain_id: Identifier for a sequence in the RNG. This should be made
+ a distinct number for each PRNG created with the same seed
+ (for example, 1:N for N PRNGS).
+ :param seed: A seed for the PRNG. If not specified, the
+ model's seed is used.
:return: A new PRNG for the model.
"""
- return StanRNG(self.stanlib, seed)
+ return StanRNG(self.stanlib, seed or self.seed, chain_id)
def param_unconstrain(
self, theta: FloatArray, *, out: Optional[FloatArray] = None
@@ -606,13 +609,13 @@ def _handle_error(self, err: ctypes.c_char_p, method: str) -> Exception:
class StanRNG:
- def __init__(self, lib: ctypes.CDLL, seed: int) -> None:
+ def __init__(self, lib: ctypes.CDLL, seed: int, chain_id: int) -> None:
self.stanlib = lib
construct = self.stanlib.bs_construct_rng
construct.restype = ctypes.c_void_p
- construct.argtypes = [ctypes.c_uint, star_star_char]
- self.ptr = construct(seed, None)
+ construct.argtypes = [ctypes.c_uint, ctypes.c_uint, star_star_char]
+ self.ptr = construct(seed, chain_id, None)
if not self.ptr:
raise RuntimeError("Failed to construct RNG.")
diff --git a/python/test/test_stanmodel.py b/python/test/test_stanmodel.py
index 32914d77..4d319ac1 100644
--- a/python/test/test_stanmodel.py
+++ b/python/test/test_stanmodel.py
@@ -208,7 +208,7 @@ def test_param_constrain():
full_so = str(STAN_FOLDER / "full" / "full_model.so")
bridge2 = bs.StanModel(full_so)
- rng = bridge2.new_rng(seed=1234)
+ rng = bridge2.new_rng(chain_id=1, seed=1234)
np.testing.assert_equal(1, bridge2.param_constrain(a).size)
np.testing.assert_equal(2, bridge2.param_constrain(a, include_tp=True).size)
@@ -221,8 +221,8 @@ def test_param_constrain():
# reproducibility test
np.testing.assert_equal(
- bridge2.param_constrain(a, include_gq=True, seed=1234),
- bridge2.param_constrain(a, include_gq=True, seed=1234),
+ bridge2.param_constrain(a, include_gq=True, chain_id=4),
+ bridge2.param_constrain(a, include_gq=True, chain_id=4),
)
# test error if neither seed or rng is provided
@@ -251,7 +251,7 @@ def test_param_constrain():
bridge3 = bs.StanModel(throw_gq_so)
bridge3.param_constrain(y, include_gq=False)
with pytest.raises(RuntimeError, match="find this text: gqfails"):
- bridge3.param_constrain(y, include_gq=True, seed=123)
+ bridge3.param_constrain(y, include_gq=True, chain_id=1)
def test_param_unconstrain():
diff --git a/src/bridgestan.cpp b/src/bridgestan.cpp
index 18cc76ff..2d590e2c 100644
--- a/src/bridgestan.cpp
+++ b/src/bridgestan.cpp
@@ -85,11 +85,11 @@ int bs_param_constrain(const bs_model* mr, bool include_tp, bool include_gq,
return 1;
}
-int bs_param_constrain_seed(const bs_model* mr, bool include_tp,
+int bs_param_constrain_id(const bs_model* mr, bool include_tp,
bool include_gq, const double* theta_unc,
- double* theta, unsigned int seed,
+ double* theta, unsigned int chain_id,
char** error_msg) {
- bs_rng rng(seed);
+ bs_rng rng(mr->seed(), chain_id);
return bs_param_constrain(mr, include_tp, include_gq, theta_unc, theta, &rng,
error_msg);
}
@@ -208,9 +208,9 @@ int bs_log_density_hessian(const bs_model* mr, bool propto, bool jacobian,
return -1;
}
-bs_rng* bs_construct_rng(unsigned int seed, char** error_msg) {
+bs_rng* bs_construct_rng(unsigned int seed, unsigned int chain_id, char** error_msg) {
try {
- return new bs_rng(seed);
+ return new bs_rng(seed, chain_id);
} catch (const std::exception& e) {
if (error_msg) {
std::stringstream error;
diff --git a/src/bridgestan.h b/src/bridgestan.h
index 39cae4e5..16cee57f 100644
--- a/src/bridgestan.h
+++ b/src/bridgestan.h
@@ -23,7 +23,8 @@ typedef int bool;
* @return pointer to constructed model or `nullptr` if construction
* fails
*/
-bs_model* bs_construct(const char* data_file, unsigned int seed, char** error_msg);
+bs_model* bs_construct(const char* data_file, unsigned int seed,
+ char** error_msg);
/**
* Destroy the model and return 0 for success and -1 if there is an
@@ -163,25 +164,26 @@ int bs_param_constrain(const bs_model* mr, bool include_tp, bool include_gq,
* in the Stan program, with multivariate parameters given in
* last-index-major order.
*
- * This version accepts a seed which is used to create a fresh PRNG
- * which lives only for the duration of this call.
+ * This version accepts a chain_id which is used to create a PRNG
+ * offset from the model's seed which lives only for the duration
+ * of this call.
*
* @param[in] mr pointer to model and RNG structure
* @param[in] include_tp `true` to include transformed parameters
* @param[in] include_gq `true` to include generated quantities
* @param[in] theta_unc sequence of unconstrained parameters
* @param[out] theta sequence of constrained parameters
- * @param[in] seed seed for pseudorandom number generator which will be created
- * and destroyed during this call. See `bs_param_constrain` for an option with a
- * persistent RNG.
+ * @param[in] chain_id offset for pseudorandom number generator which will be
+ * created and destroyed during this call. seeded with model seed. See
+ * `bs_param_constrain` for an option with a persistent RNG.
* @param[out] error_msg a pointer to a string that will be allocated if there
* is an error. This must later be freed by calling `bs_free_error_msg`.
* @return code 0 if successful and code -1 if there is an exception
* in the underlying Stan code
*/
-int bs_param_constrain_seed(const bs_model* mr, bool include_tp,
- bool include_gq, const double* theta_unc,
- double* theta, unsigned int seed, char** error_msg);
+int bs_param_constrain_id(const bs_model* mr, bool include_tp, bool include_gq,
+ const double* theta_unc, double* theta,
+ unsigned int chain_id, char** error_msg);
/**
* Set the sequence of unconstrained parameters based on the
@@ -294,15 +296,17 @@ int bs_log_density_hessian(const bs_model* mr, bool propto, bool jacobian,
double* hessian, char** error_msg);
/**
- * Construct an RNG object to be used in `bs_param_constrain`.
+ * Construct an PRNG object to be used in `bs_param_constrain`.
* This object is not thread safe and should be constructed and
* destructed for each thread.
*
* @param[in] seed seed for the RNG
+ * @param[in] chain_id identifier for the current sequence of PRNG draws
* @param[out] error_msg a pointer to a string that will be allocated if there
* is an error. This must later be freed by calling `bs_free_error_msg`.
*/
-bs_rng* bs_construct_rng(unsigned int seed, char** error_msg);
+bs_rng* bs_construct_rng(unsigned int seed, unsigned int chain_id,
+ char** error_msg);
/**
* Destruct an RNG object.
diff --git a/src/bridgestanR.cpp b/src/bridgestanR.cpp
index e1e0f246..cc5cdc88 100644
--- a/src/bridgestanR.cpp
+++ b/src/bridgestanR.cpp
@@ -39,12 +39,12 @@ void bs_param_constrain_R(bs_model** model, int* include_tp, int* include_gq,
theta, *rng, err_msg);
*err_ptr = static_cast(*err_msg);
}
-void bs_param_constrain_seed_R(bs_model** model, int* include_tp,
+void bs_param_constrain_id_R(bs_model** model, int* include_tp,
int* include_gq, const double* theta_unc,
- double* theta, int* seed, int* return_code,
+ double* theta, int* chain_id, int* return_code,
char** err_msg, void** err_ptr) {
- *return_code = bs_param_constrain_seed(*model, *include_tp, *include_gq,
- theta_unc, theta, *seed, err_msg);
+ *return_code = bs_param_constrain_id(*model, *include_tp, *include_gq,
+ theta_unc, theta, *chain_id, err_msg);
*err_ptr = static_cast(*err_msg);
}
void bs_param_unconstrain_R(bs_model** model, const double* theta,
@@ -82,9 +82,9 @@ void bs_log_density_hessian_R(bs_model** model, int* propto, int* jacobian,
val, grad, hess, err_msg);
*err_ptr = static_cast(*err_msg);
}
-void bs_construct_rng_R(int* seed, bs_rng** ptr_out, char** err_msg,
+void bs_construct_rng_R(int* seed, int* chain_id, bs_rng** ptr_out, char** err_msg,
void** err_ptr) {
- *ptr_out = bs_construct_rng(*seed, err_msg);
+ *ptr_out = bs_construct_rng(*seed, *chain_id, err_msg);
*err_ptr = static_cast(*err_msg);
}
void bs_destruct_rng_R(bs_rng** rng, int* return_code) {
diff --git a/src/bridgestanR.h b/src/bridgestanR.h
index 811580b5..f14f4481 100644
--- a/src/bridgestanR.h
+++ b/src/bridgestanR.h
@@ -41,9 +41,9 @@ void bs_param_constrain_R(bs_model** model, int* include_tp, int* include_gq,
const double* theta_unc, double* theta, bs_rng** rng,
int* return_code, char** err_msg, void** err_ptr);
-void bs_param_constrain_seed_R(bs_model** model, int* include_tp,
+void bs_param_constrain_id_R(bs_model** model, int* include_tp,
int* include_gq, const double* theta_unc,
- double* theta, int* seed, int* return_code,
+ double* theta, int* chain_id, int* return_code,
char** err_msg, void** err_ptr);
void bs_param_unconstrain_R(bs_model** model, const double* theta,
@@ -68,7 +68,7 @@ void bs_log_density_hessian_R(bs_model** model, int* propto, int* jacobian,
double* grad, double* hess, int* return_code,
char** err_msg, void** err_ptr);
-void bs_construct_rng_R(int* seed, bs_rng** ptr_out, char** err_msg,
+void bs_construct_rng_R(int* seed, int* chain_id, bs_rng** ptr_out, char** err_msg,
void** err_ptr);
void bs_destruct_rng_R(bs_rng** rng, int* return_code);
diff --git a/src/model_rng.cpp b/src/model_rng.cpp
index 3955cadb..887ae6c5 100644
--- a/src/model_rng.cpp
+++ b/src/model_rng.cpp
@@ -78,6 +78,8 @@ bs_model::bs_model(const char* data_file, unsigned int seed) {
}
}
+ seed_ = seed;
+
std::string model_name = model_->model_name();
const char* model_name_c = model_name.c_str();
name_ = strdup(model_name_c);
@@ -167,6 +169,8 @@ bs_model::~bs_model() {
const char* bs_model::name() const { return name_; }
+unsigned int bs_model::seed() const { return seed_; }
+
const char* bs_model::model_info() const { return model_info_; }
const char* bs_model::param_names(bool include_tp, bool include_gq) const {
@@ -323,3 +327,7 @@ void bs_model::log_density_hessian(bool propto, bool jacobian,
Eigen::VectorXd::Map(grad, N) = grad_vec;
Eigen::MatrixXd::Map(hessian, N, N) = hess_mat;
}
+
+bs_rng::bs_rng(unsigned int seed, unsigned int chain_id) {
+ rng_ = stan::services::util::create_rng(seed, chain_id + 1);
+}
diff --git a/src/model_rng.hpp b/src/model_rng.hpp
index d83a4f2d..031e2dfd 100644
--- a/src/model_rng.hpp
+++ b/src/model_rng.hpp
@@ -38,6 +38,14 @@ class bs_model {
*/
const char* name() const;
+ /**
+ * Return the pseudorandom number generator seed
+ * used during model construction.
+ *
+ * @return seed
+ */
+ unsigned int seed() const;
+
/**
* Return information about the compiled model. This class manages the
* memory, so the returned string should not be freed.
@@ -182,6 +190,9 @@ class bs_model {
/** Stan model */
stan::model::model_base* model_;
+ /** RNG seed provided during model creation */
+ unsigned int seed_;
+
/** name of the Stan model */
char* name_ = nullptr;
@@ -230,7 +241,7 @@ class bs_model {
*/
class bs_rng {
public:
- bs_rng(unsigned int seed) : rng_(seed) { rng_.discard(1); }
+ bs_rng(unsigned int seed, unsigned int chain_id);
boost::ecuyer1988 rng_;
};