Commit

Merge pull request #350 from stan-dev/mpi
MPI sampling
rok-cesnovar authored Dec 4, 2020
2 parents f24b5e0 + 6e1c24b commit 291ceb3
Showing 14 changed files with 404 additions and 8 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/R-CMD-check.yaml
@@ -53,6 +53,12 @@ jobs:
           mingw32-make --version
           Get-Command mingw32-make | Select-Object -ExpandProperty Definition
         shell: powershell

+      - name: Install MPI
+        if: runner.os == 'Linux'
+        run: |
+          sudo apt-get install -y openmpi-bin
+          echo "CMDSTANR_RUN_MPI_TESTS=TRUE" >> $GITHUB_ENV
       - uses: r-lib/actions/setup-r@master
         with:
6 changes: 5 additions & 1 deletion .github/workflows/Test-coverage.yaml
@@ -27,7 +27,11 @@ jobs:
       - uses: r-lib/actions/setup-pandoc@master

       - name: Install Ubuntu dependencies
-        run: sudo apt-get install libcurl4-openssl-dev
+        run: |
+          sudo apt-get install libcurl4-openssl-dev
+          sudo apt-get install -y openmpi-bin
+          echo "CMDSTANR_RUN_MPI_TESTS=TRUE" >> $GITHUB_ENV
       - name: Query dependencies
         run: |
           install.packages('remotes')
8 changes: 8 additions & 0 deletions NEWS.md
@@ -1,3 +1,11 @@
+# Items for next tagged release
+
+### Bug fixes
+
+### New features
+
+* Added `$sample_mpi()` for MCMC sampling with MPI. (#350)
+
 # cmdstanr 0.2.2

 ### Bug fixes
175 changes: 175 additions & 0 deletions R/model.R
@@ -172,6 +172,7 @@ cmdstan_model <- function(stan_file, compile = TRUE, ...) {
 #' |**Method**|**Description**|
 #' |:----------|:---------------|
 #' [`$sample()`][model-method-sample] | Run CmdStan's `"sample"` method, return [`CmdStanMCMC`] object. |
+#' [`$sample_mpi()`][model-method-sample_mpi] | Run CmdStan's `"sample"` method with [MPI](https://mc-stan.org/math/mpi.html), return [`CmdStanMCMC`] object. |
 #' [`$optimize()`][model-method-optimize] | Run CmdStan's `"optimize"` method, return [`CmdStanMLE`] object. |
 #' [`$variational()`][model-method-variational] | Run CmdStan's `"variational"` method, return [`CmdStanVB`] object. |
 #' [`$generate_quantities()`][model-method-generate-quantities] | Run CmdStan's `"generate quantities"` method, return [`CmdStanGQ`] object. |
@@ -964,6 +965,180 @@ sample_method <- function(data = NULL,
 }
 CmdStanModel$set("public", name = "sample", value = sample_method)

+#' Run Stan's MCMC algorithms with MPI
+#'
+#' @name model-method-sample_mpi
+#' @aliases sample_mpi
+#' @family CmdStanModel methods
+#'
+#' @description The `$sample_mpi()` method of a [`CmdStanModel`] object is
+#' identical to the `$sample()` method but with support for
+#' [MPI](https://mc-stan.org/math/mpi.html). The target audience for MPI is
+#' users with large computer clusters. For other users, the
+#' [`$sample()`][model-method-sample] method provides both parallelization of
+#' chains and threading support for within-chain parallelization.
+#'
+#' @details In order to use MPI with Stan, an MPI implementation must be
+#' installed. For Unix systems the most commonly used implementations are
+#' MPICH and OpenMPI. The implementations provide an MPI C++ compiler wrapper
+#' (for example, `mpicxx`), which is required to compile the model.
+#'
+#' An example of compiling with MPI:
+#' ```
+#' mpi_options <- list(STAN_MPI=TRUE, CXX="mpicxx", TBB_CXX_TYPE="gcc")
+#' mod <- cmdstan_model("model.stan", cpp_options = mpi_options)
+#' ```
+#' The C++ options that must be supplied to the
+#' [compile][model-method-compile] call are:
+#' - `STAN_MPI`: Enables the use of MPI with Stan if `TRUE`.
+#' - `CXX`: The name of the MPI C++ compiler wrapper. Typically `"mpicxx"`.
+#' - `TBB_CXX_TYPE`: The C++ compiler the MPI wrapper wraps. Typically `"gcc"`
+#' on Linux and `"clang"` on macOS.
+#'
+#' In the call to the `$sample_mpi()` method we can also provide the name of
+#' the MPI launcher (`mpi_cmd`, defaulting to `"mpiexec"`) and any other
+#' MPI launch arguments. In most cases, it is enough to only define the number
+#' of processes with `mpi_args = list("n" = 4)`.
+#'
+#' @section Usage:
+#' ```
+#' $sample_mpi(
+#' data = NULL,
+#' mpi_cmd = "mpiexec",
+#' mpi_args = NULL,
+#' seed = NULL,
+#' refresh = NULL,
+#' init = NULL,
+#' save_latent_dynamics = FALSE,
+#' output_dir = NULL,
+#' chains = 1,
+#' chain_ids = seq_len(chains),
+#' iter_warmup = NULL,
+#' iter_sampling = NULL,
+#' save_warmup = FALSE,
+#' thin = NULL,
+#' max_treedepth = NULL,
+#' adapt_engaged = TRUE,
+#' adapt_delta = NULL,
+#' step_size = NULL,
+#' metric = NULL,
+#' metric_file = NULL,
+#' inv_metric = NULL,
+#' init_buffer = NULL,
+#' term_buffer = NULL,
+#' window = NULL,
+#' fixed_param = FALSE,
+#' sig_figs = NULL,
+#' validate_csv = TRUE,
+#' show_messages = TRUE
+#' )
+#' ```
+#'
+#' @section Arguments unique to the `sample_mpi` method:
+#' * `mpi_cmd`: (character vector) The MPI launcher used for launching MPI processes.
+#' The default launcher is `"mpiexec"`.
+#' * `mpi_args`: (list) A list of arguments to use when launching MPI processes.
+#' For example, `mpi_args = list("n" = 4)` launches the executable as
+#' `mpiexec -n 4 model_executable`, followed by CmdStan arguments
+#' for the model executable.
+#'
+#' All other arguments are the same as for [`$sample()`][model-method-sample],
+#' except that `$sample_mpi()` does not have the arguments `threads_per_chain`
+#' or `parallel_chains`.
+#'
+#' @section Value: The `$sample_mpi()` method returns a [`CmdStanMCMC`] object.
+#'
+#' @template seealso-docs
+#' @seealso The Stan Math Library's MPI documentation
+#' ([mc-stan.org/math/mpi](https://mc-stan.org/math/mpi.html)) for more
+#' details on MPI support in Stan.
+#'
+#' @examples
+#' \dontrun{
+#' # mpi_options <- list(STAN_MPI=TRUE, CXX="mpicxx", TBB_CXX_TYPE="gcc")
+#' # mod <- cmdstan_model("model.stan", cpp_options = mpi_options)
+#' # fit <- mod$sample_mpi(..., mpi_args = list("n" = 4))
+#' }
+#'
+NULL
+
+sample_mpi_method <- function(data = NULL,
+                              mpi_cmd = "mpiexec",
+                              mpi_args = NULL,
+                              seed = NULL,
+                              refresh = NULL,
+                              init = NULL,
+                              save_latent_dynamics = FALSE,
+                              output_dir = NULL,
+                              chains = 1,
+                              chain_ids = seq_len(chains),
+                              iter_warmup = NULL,
+                              iter_sampling = NULL,
+                              save_warmup = FALSE,
+                              thin = NULL,
+                              max_treedepth = NULL,
+                              adapt_engaged = TRUE,
+                              adapt_delta = NULL,
+                              step_size = NULL,
+                              metric = NULL,
+                              metric_file = NULL,
+                              inv_metric = NULL,
+                              init_buffer = NULL,
+                              term_buffer = NULL,
+                              window = NULL,
+                              fixed_param = FALSE,
+                              sig_figs = NULL,
+                              validate_csv = TRUE,
+                              show_messages = TRUE) {
+  if (fixed_param) {
+    chains <- 1
+    save_warmup <- FALSE
+  }
+
+  checkmate::assert_integerish(chains, lower = 1, len = 1)
+  checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE)
+  sample_args <- SampleArgs$new(
+    iter_warmup = iter_warmup,
+    iter_sampling = iter_sampling,
+    save_warmup = save_warmup,
+    thin = thin,
+    max_treedepth = max_treedepth,
+    adapt_engaged = adapt_engaged,
+    adapt_delta = adapt_delta,
+    step_size = step_size,
+    metric = metric,
+    metric_file = metric_file,
+    inv_metric = inv_metric,
+    init_buffer = init_buffer,
+    term_buffer = term_buffer,
+    window = window,
+    fixed_param = fixed_param
+  )
+  cmdstan_args <- CmdStanArgs$new(
+    method_args = sample_args,
+    model_name = strip_ext(basename(self$exe_file())),
+    exe_file = self$exe_file(),
+    proc_ids = chain_ids,
+    data_file = process_data(data),
+    save_latent_dynamics = save_latent_dynamics,
+    seed = seed,
+    init = init,
+    refresh = refresh,
+    output_dir = output_dir,
+    validate_csv = validate_csv,
+    sig_figs = sig_figs
+  )
+  cmdstan_procs <- CmdStanMCMCProcs$new(
+    num_procs = chains,
+    parallel_procs = 1,
+    show_stderr_messages = show_messages
+  )
+  runset <- CmdStanRun$new(args = cmdstan_args, procs = cmdstan_procs)
+  runset$run_cmdstan_mpi(mpi_cmd, mpi_args)
+  CmdStanMCMC$new(runset)
+}
+CmdStanModel$set("public", name = "sample_mpi", value = sample_mpi_method)

 #' Run Stan's optimization algorithms
 #'
46 changes: 39 additions & 7 deletions R/run.R
@@ -20,7 +20,6 @@ CmdStanRun <- R6::R6Class(
       }
       invisible(self)
     },
-
     num_procs = function() self$procs$num_procs(),
     proc_ids = function() self$procs$proc_ids(),
     exe_file = function() self$args$exe_file,
@@ -150,6 +149,10 @@ CmdStanRun <- R6::R6Class(
       }
     },

+    run_cmdstan_mpi = function(mpi_cmd, mpi_args) {
+      private$run_sample_(mpi_cmd, mpi_args)
+    },
+
     # run bin/stansummary or bin/diagnose
     # @param tool The name of the tool in `bin/` to run.
     # @param flags An optional character vector of flags (e.g. c("--sig_figs=1")).
@@ -221,10 +224,15 @@ CmdStanRun <- R6::R6Class(


 # run helpers -------------------------------------------------
-.run_sample <- function() {
+.run_sample <- function(mpi_cmd = NULL, mpi_args = NULL) {
   procs <- self$procs
   on.exit(procs$cleanup(), add = TRUE)
-
+  if (!is.null(mpi_cmd)) {
+    if (is.null(mpi_args)) {
+      mpi_args <- list()
+    }
+    mpi_args[["exe"]] <- self$exe_file()
+  }
   # add path to the TBB library to the PATH variable
   if (cmdstan_version() >= "2.21" && os_is_windows()) {
     path_to_TBB <- file.path(cmdstan_path(), "stan", "lib", "stan_math", "lib", "tbb")
@@ -239,7 +247,20 @@ CmdStanRun <- R6::R6Class(
     start_msg <- paste0("Running MCMC with ", procs$num_procs(), " parallel chains")
   } else {
     if (procs$parallel_procs() == 1) {
-      start_msg <- paste0("Running MCMC with ", procs$num_procs(), " sequential chains")
+      if (!is.null(mpi_cmd)) {
+        if (!is.null(mpi_args[["n"]])) {
+          mpi_n_process <- mpi_args[["n"]]
+        } else {
+          mpi_n_process <- mpi_args[["np"]]  # NULL if "np" is absent too, so the check below is safe
+        }
+        if (is.null(mpi_n_process)) {
+          start_msg <- paste0("Running MCMC with ", procs$num_procs(), " chains using MPI")
+        } else {
+          start_msg <- paste0("Running MCMC with ", procs$num_procs(), " chains using MPI with ", mpi_n_process, " processes")
+        }
+      } else {
+        start_msg <- paste0("Running MCMC with ", procs$num_procs(), " sequential chains")
+      }
     } else {
       start_msg <- paste0("Running MCMC with ", procs$num_procs(), " chains, at most ", procs$parallel_procs(), " in parallel")
     }
@@ -260,7 +281,9 @@ CmdStanRun <- R6::R6Class(
         id = chain_id,
         command = self$command(),
         args = self$command_args()[[chain_id]],
-        wd = dirname(self$exe_file())
+        wd = dirname(self$exe_file()),
+        mpi_cmd = mpi_cmd,
+        mpi_args = mpi_args
       )
       procs$mark_proc_start(chain_id)
       procs$set_active_procs(procs$active_procs() + 1)
@@ -477,12 +500,21 @@ CmdStanProcs <- R6::R6Class(
     get_proc = function(id) {
       private$processes_[[id]]
     },
-    new_proc = function(id, command, args, wd) {
+    new_proc = function(id, command, args, wd, mpi_cmd = NULL, mpi_args = NULL) {
+      if (!is.null(mpi_cmd)) {
+        exe_name <- mpi_args[["exe"]]
+        mpi_args[["exe"]] <- NULL
+        mpi_args_vector <- c()
+        for (i in names(mpi_args)) {
+          mpi_args_vector <- c(paste0("-", i), mpi_args[[i]], mpi_args_vector)
+        }
+        args <- c(mpi_args_vector, exe_name, args)
+        command <- mpi_cmd
+      }
       private$processes_[[id]] <- processx::process$new(
         command = command,
         args = args,
         wd = wd,
-        echo_cmd = FALSE,
         stdout = "|",
         stderr = "|"
       )
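
To make the `new_proc()` change above easier to follow, here is a minimal standalone sketch of how the MPI launcher command line is assembled from `mpi_args`. The helper name `build_mpi_command()`, the executable path `./model`, and the CmdStan arguments shown are assumptions made for illustration; they are not part of cmdstanr. The loop copies the logic in the diff, including the detail that each new flag/value pair is prepended, so the flags end up in the reverse of their order in the list.

```r
# Hypothetical helper (not part of cmdstanr): mirrors the argument assembly
# that new_proc() performs when mpi_cmd is supplied.
build_mpi_command <- function(mpi_cmd, mpi_args, exe_file, cmdstan_args) {
  mpi_args_vector <- c()
  for (name in names(mpi_args)) {
    # Each list element becomes "-<name>" followed by its value. As in the
    # diff, the pair is prepended, so the final order is reversed.
    mpi_args_vector <- c(paste0("-", name), mpi_args[[name]], mpi_args_vector)
  }
  # The launcher becomes the command; the model executable moves into args.
  list(command = mpi_cmd, args = c(mpi_args_vector, exe_file, cmdstan_args))
}

cmd <- build_mpi_command(
  mpi_cmd = "mpiexec",
  mpi_args = list(n = 4),
  exe_file = "./model",                          # placeholder executable path
  cmdstan_args = c("id=1", "method", "sample")   # illustrative arguments only
)
print(paste(c(cmd$command, cmd$args), collapse = " "))
# [1] "mpiexec -n 4 ./model id=1 method sample"
```

Because the values are combined with `c()`, a numeric process count such as `4` is coerced to the string `"4"`, which is the form `processx` expects for command arguments.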
1 change: 1 addition & 0 deletions man/CmdStanModel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/model-method-check_syntax.Rd

1 change: 1 addition & 0 deletions man/model-method-compile.Rd

1 change: 1 addition & 0 deletions man/model-method-generate-quantities.Rd

1 change: 1 addition & 0 deletions man/model-method-optimize.Rd

1 change: 1 addition & 0 deletions man/model-method-sample.Rd
