MPI sampling #350
Changes from 15 commits
@@ -914,6 +914,180 @@ sample_method <- function(data = NULL,
}
CmdStanModel$set("public", name = "sample", value = sample_method)

#' Run Stan's MCMC algorithms with MPI
#'
#' @name model-method-mpi-sample
#' @aliases mpi_sample
#' @family CmdStanModel methods
#'
#' @description The `$mpi_sample()` method of a [`CmdStanModel`] object runs the
#' default MCMC algorithm in CmdStan (`algorithm=hmc engine=nuts`) with MPI
#' (the `STAN_MPI` makefile flag) to produce a set of draws from the posterior
#' distribution of a model conditioned on some data.
#'
#' In order to use MPI with Stan, an MPI implementation must be installed. On
#' Unix systems the most commonly used implementations are MPICH and OpenMPI.
#' These implementations provide an MPI C++ compiler wrapper (for example
#' `mpicxx`), which is required to compile the model.
#'
#' An example of compiling with `STAN_MPI`:
#' ```
#' cpp_options = list(STAN_MPI = TRUE, CXX = "mpicxx", TBB_CXX_TYPE = "gcc")
#' mod <- cmdstan_model("model.stan", cpp_options = cpp_options)
#' ```
#' The C++ options that need to be supplied to the compile call are:
#' - `STAN_MPI`: Enables the use of MPI with Stan.
#' - `CXX`: The name of the MPI C++ compiler wrapper (typically `mpicxx`).
#' - `TBB_CXX_TYPE`: The C++ compiler the MPI wrapper wraps. Typically `gcc`
#'   on Linux and `clang` on macOS.
#'
#' In the call to the `$mpi_sample()` method we can additionally provide the
#' name of the MPI launcher (`mpi_cmd`), which defaults to `"mpiexec"`, and
#' any other MPI launch arguments. In most cases it is enough to only define
#' the number of processes with `mpi_args = list("n" = 4)`.
#'
#' An example call of `$mpi_sample()`:
#' ```
#' fit <- mod$mpi_sample(data_list, mpi_args = list("n" = 4))
#' ```
#'
#' @section Usage:
#' ```
#' $mpi_sample(
#'   data = NULL,
#'   mpi_cmd = "mpiexec",
#'   mpi_args = NULL,
#'   seed = NULL,
#'   refresh = NULL,
#'   init = NULL,
#'   save_latent_dynamics = FALSE,
#'   output_dir = NULL,
#'   chains = 1,
#'   parallel_chains = getOption("mc.cores", 1),
#'   chain_ids = seq_len(chains),
#'   iter_warmup = NULL,
#'   iter_sampling = NULL,
#'   save_warmup = FALSE,
#'   thin = NULL,
#'   max_treedepth = NULL,
#'   adapt_engaged = TRUE,
#'   adapt_delta = NULL,
#'   step_size = NULL,
#'   metric = NULL,
#'   metric_file = NULL,
#'   inv_metric = NULL,
#'   init_buffer = NULL,
#'   term_buffer = NULL,
#'   window = NULL,
#'   fixed_param = FALSE,
#'   sig_figs = NULL,
#'   validate_csv = TRUE,
#'   show_messages = TRUE
#' )
#' ```
#'
#' @section Arguments:
#' * `mpi_cmd`: (string) The MPI launcher used for launching MPI processes.
#'   The default launcher is `"mpiexec"`.
#' * `mpi_args`: (list) A list of arguments to use when launching MPI
#'   processes. For example, `mpi_args = list("n" = 4)` launches the
#'   executable as `mpiexec -n 4 model_executable`, followed by the CmdStan
#'   arguments for the model executable.
#' * `data`, `seed`, `refresh`, `init`, `save_latent_dynamics`, `output_dir`,
#'   `chains`, `parallel_chains`, `chain_ids`, `iter_warmup`, `iter_sampling`,
#'   `save_warmup`, `thin`, `max_treedepth`, `adapt_engaged`, `adapt_delta`,
#'   `step_size`, `metric`, `metric_file`, `inv_metric`, `init_buffer`,
#'   `term_buffer`, `window`, `fixed_param`, `sig_figs`, `validate_csv`,
#'   `show_messages`:
#'   Same as for the [`$sample()`][model-method-sample] method.
#'
#' @section Value: The `$mpi_sample()` method returns a [`CmdStanMCMC`] object.
#'
#' @template seealso-docs
#' @inherit cmdstan_model examples
#'
NULL
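To make the launcher behavior above concrete, here is a hedged shell sketch of the command line that a call like `mod$mpi_sample(data, mpi_args = list("n" = 4))` would assemble. `./model` stands in for the compiled executable, and the CmdStan arguments after it are illustrative, not the exact set the method produces.

```shell
# Sketch only: mpi_cmd defaults to "mpiexec", and mpi_args = list("n" = 4)
# is rendered as the launcher flag "-n 4" in front of the executable.
mpi_cmd="mpiexec"
n_procs=4
exe="./model"
echo "$mpi_cmd -n $n_procs $exe sample num_warmup=1000 num_samples=1000 data file=data.json"
```

The key point is that the MPI launcher wraps the usual CmdStan invocation; everything after the executable name is ordinary CmdStan argument syntax.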
mpi_sample_method <- function(data = NULL,
                              mpi_cmd = "mpiexec",
                              mpi_args = NULL,
                              seed = NULL,
                              refresh = NULL,
                              init = NULL,
                              save_latent_dynamics = FALSE,
                              output_dir = NULL,
                              chains = 1,
                              parallel_chains = getOption("mc.cores", 1),
                              chain_ids = seq_len(chains),
                              iter_warmup = NULL,
                              iter_sampling = NULL,
                              save_warmup = FALSE,
                              thin = NULL,
                              max_treedepth = NULL,
                              adapt_engaged = TRUE,
                              adapt_delta = NULL,
                              step_size = NULL,
                              metric = NULL,
                              metric_file = NULL,
                              inv_metric = NULL,
                              init_buffer = NULL,
                              term_buffer = NULL,
                              window = NULL,
                              fixed_param = FALSE,
                              sig_figs = NULL,
                              validate_csv = TRUE,
                              show_messages = TRUE) {
It's not unusual to use threading in an MPI process, it's just that we haven't done that for Stan.

Agreed. If we enable it in Stan (which we should), we can then add threading here also.
  if (fixed_param) {
    chains <- 1
    parallel_chains <- 1
    save_warmup <- FALSE
  }

  checkmate::assert_integerish(chains, lower = 1, len = 1)
  checkmate::assert_integerish(parallel_chains, lower = 1, null.ok = TRUE)
  checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE)
  sample_args <- SampleArgs$new(
    iter_warmup = iter_warmup,
    iter_sampling = iter_sampling,
    save_warmup = save_warmup,
    thin = thin,
    max_treedepth = max_treedepth,
    adapt_engaged = adapt_engaged,
    adapt_delta = adapt_delta,
    step_size = step_size,
    metric = metric,
    metric_file = metric_file,
    inv_metric = inv_metric,
    init_buffer = init_buffer,
    term_buffer = term_buffer,
    window = window,
    fixed_param = fixed_param
  )
  cmdstan_args <- CmdStanArgs$new(
    method_args = sample_args,
    model_name = strip_ext(basename(self$exe_file())),
    exe_file = self$exe_file(),
    proc_ids = chain_ids,
    data_file = process_data(data),
    save_latent_dynamics = save_latent_dynamics,
    seed = seed,
    init = init,
    refresh = refresh,
    output_dir = output_dir,
    validate_csv = validate_csv,
    sig_figs = sig_figs
  )
  cmdstan_procs <- CmdStanMCMCProcs$new(
    num_procs = chains,
    parallel_procs = parallel_chains,
    show_stderr_messages = show_messages
  )
  runset <- CmdStanRun$new(args = cmdstan_args, procs = cmdstan_procs)
  runset$run_cmdstan_mpi(mpi_cmd, mpi_args)
  CmdStanMCMC$new(runset)
}
CmdStanModel$set("public", name = "mpi_sample", value = mpi_sample_method)
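The actual flattening of `mpi_args` happens inside `run_cmdstan_mpi()`, whose body is not part of this diff. As a purely hypothetical sketch of what that flattening could look like, a named entry such as `"n" = 4` can be turned into the launcher flag `-n 4`:

```shell
# Hypothetical sketch: flatten name=value pairs (here only "n=4",
# mirroring mpi_args = list("n" = 4)) into mpiexec flags. The real R
# implementation may differ in details.
flags=""
for pair in "n=4"; do
  name="${pair%%=*}"
  value="${pair#*=}"
  flags="$flags -$name $value"
done
echo "mpiexec$flags ./model sample"
```

This is only meant to show the shape of the mapping from the R-level named list to the command-line flags.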
#' Run Stan's optimization algorithms
#'
The default for `chains` in `$mpi_sample()` is 1, while for `$sample()` it's 4. Or should we leave it the same as for `$sample()`?
What happens when `chains = 4` in terms of distributing processes? With `mpiexec -n 4`, is each chain solved by 1 process?
If it's `chains = 4, parallel_chains = 1`, that is the same as 4 sequential `$mpi_sample(chains = 1)` calls, and this is just a convenience so the draws are merged together in the fit. If it's `chains = 4, parallel_chains = 4`, that means 4 `mpiexec` calls with `n = 4` all running at the same time. Not sure that is useful, or if we should just fix `parallel_chains` to 1.
I'm confused, as I thought `mpi_args = c("-n", 4, ...)` controls the total number of MPI processes. But it looks like `parallel_chains = 4` implies `mpiexec -n 4` too. Is that right?
Without MPI, `parallel_chains = 4` just means that 4 `./model args` calls are made and 4 model processes run in parallel, as if they were launched together from a shell. In `$mpi_sample()` this would mean 4 processes that would each run `mpiexec -n x ./model args`. Does that make more sense?
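The shell illustration referenced in the comment above appears to have been lost in the page capture. A minimal stand-in, assuming `./model` is the compiled CmdStan executable, would be:

```shell
# Stand-in sketch: parallel_chains = 4 without MPI means four model
# processes running concurrently, one per chain. Here we only print the
# commands; a real run would append '&' to each line and finish with
# 'wait' so all four run at once.
for id in 1 2 3 4; do
  echo "./model sample id=$id output file=output_${id}.csv"
done
```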
I find this really confusing, as the arguments `chains` and `parallel_chains` seem to have overlapping meanings. What happens if we remove `parallel_chains` and only use `chains` in `mpi_sample`?
This comes mainly from non-parallel MCMC sampling. For example, users may want to run 4 chains but use only 2 cores for the chains and keep the other two free for something else. rstan uses the `cores` argument for this same thing. I think this makes less sense, or is less useful, in the context of within-chain parallelization with threading or MPI, because if someone uses parallelization it's likely they want all the CPU/cluster power.

We can remove `parallel_chains`, we just need to decide what to do in the case of `chains > 1`. Do we run the chains sequentially (`parallel_chains = 1`) or run all of them at once (`parallel_chains = chains`)? My gut feeling is that we go with the former?
My assumption is just that one typically runs a single `mpiexec` with maximum n?

It depends on the context. For now I agree with you; semantically it makes more sense to have `chains = 4` plus `mpi_args = c("-n", "x")` be equivalent to running the chains sequentially.
Ok, let's think about this a bit more. I am not sure what the best solution would be.