Skip to content

Commit

Permalink
Merge pull request #696 from andrjohns/wsl-cmdstan-internal
Browse files Browse the repository at this point in the history
WSL - Run `cmdstan` and models under WSL filesystem
  • Loading branch information
rok-cesnovar authored Oct 29, 2022
2 parents 53084da + 71c0ea7 commit f8f6321
Show file tree
Hide file tree
Showing 20 changed files with 410 additions and 171 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/R-CMD-check-wsl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ jobs:
run: |
remotes::install_deps(dependencies = TRUE)
remotes::install_cran("rcmdcheck")
remotes::install_local(path = ".")
remotes::install_local(path = ".", INSTALL_opts = "--no-test-load")
install.packages("curl")
shell: Rscript {0}

- uses: Vampire/setup-wsl@v1
with:
distribution: Ubuntu-22.04
use-cache: 'true'
use-cache: 'false'
set-as-default: 'true'
- name: Install WSL Dependencies
run: |
Expand All @@ -74,6 +74,7 @@ jobs:

- name: Install cmdstan
run: |
cmdstanr::check_cmdstan_toolchain(fix = TRUE)
cmdstanr::install_cmdstan(cores = 2, wsl = TRUE, overwrite = TRUE)
shell: Rscript {0}

Expand Down
1 change: 1 addition & 0 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ jobs:

- name: Install dependencies
run: |
Sys.setenv("MAKEFLAGS"="-j2")
remotes::install_deps(dependencies = TRUE)
remotes::install_cran("rcmdcheck")
remotes::install_local(path = ".")
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/Test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ jobs:
- name: Install dependencies
run: |
install.packages(c("remotes", "curl"), dependencies = TRUE)
remotes::install_local(path = ".")
remotes::install_local(path = ".", INSTALL_opts = "--no-test-load")
remotes::install_deps(dependencies = TRUE)
remotes::install_cran("covr")
remotes::install_cran("gridExtra")
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/cmdstan-tarball-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ jobs:
run: |
remotes::install_deps(dependencies = TRUE)
remotes::install_cran("rcmdcheck")
remotes::install_local(path = ".")
remotes::install_local(path = ".", INSTALL_opts = "--no-test-load")
cmdstanr::check_cmdstan_toolchain(fix = TRUE)
if (Sys.getenv("CMDSTAN_TEST_TARBALL_URL") == "latest") {
cmdstanr::install_cmdstan(cores = 2, overwrite = TRUE)
Expand Down
25 changes: 17 additions & 8 deletions R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,20 @@ CmdStanArgs <- R6::R6Class(
self$save_latent_dynamics <- save_latent_dynamics
self$using_tempdir <- is.null(output_dir)
self$model_variables <- model_variables
if (getRversion() < "3.5.0") {
if (os_is_wsl()) {
# Want to ensure that any files under WSL are written to a tempdir within
# WSL to avoid IO performance issues
self$output_dir <- ifelse(is.null(output_dir),
file.path(wsl_dir_prefix(), wsl_tempdir()),
wsl_safe_path(output_dir))
} else if (getRversion() < "3.5.0") {
self$output_dir <- output_dir %||% tempdir()
} else {
self$output_dir <- output_dir %||% tempdir(check = TRUE)
if (getRversion() < "3.5.0") {
self$output_dir <- output_dir %||% tempdir()
} else {
self$output_dir <- output_dir %||% tempdir(check = TRUE)
}
}
self$output_dir <- repair_path(self$output_dir)
self$output_basename <- output_basename
Expand Down Expand Up @@ -525,8 +535,7 @@ DiagnoseArgs <- R6::R6Class(
#' @return `TRUE` invisibly unless an error is thrown.
validate_cmdstan_args <- function(self) {
validate_exe_file(self$exe_file)

checkmate::assert_directory_exists(self$output_dir, access = "rw")
assert_dir_exists(self$output_dir, access = "rw")

# at least 1 run id (chain id)
checkmate::assert_integerish(self$proc_ids,
Expand All @@ -545,7 +554,7 @@ validate_cmdstan_args <- function(self) {
self$refresh <- as.integer(self$refresh)
}
if (!is.null(self$data_file)) {
checkmate::assert_file_exists(self$data_file, access = "r")
assert_file_exists(self$data_file, access = "r")
}
num_procs <- length(self$proc_ids)
validate_init(self$init, num_procs)
Expand Down Expand Up @@ -698,7 +707,7 @@ validate_optimize_args <- function(self) {
#' @return `TRUE` invisibly unless an error is thrown.
validate_generate_quantities_args <- function(self) {
if (!is.null(self$fitted_params)) {
checkmate::assert_file_exists(self$fitted_params, access = "r")
assert_file_exists(self$fitted_params, access = "r")
}

invisible(TRUE)
Expand Down Expand Up @@ -895,7 +904,7 @@ validate_init <- function(init, num_procs) {
"length 1 or number of chains.",
call. = FALSE)
}
checkmate::assert_file_exists(init, access = "r")
assert_file_exists(init, access = "r")
}

invisible(TRUE)
Expand Down Expand Up @@ -983,7 +992,7 @@ validate_metric_file <- function(metric_file, num_procs) {
return(invisible(TRUE))
}

checkmate::assert_file_exists(metric_file, access = "r")
assert_file_exists(metric_file, access = "r")

if (length(metric_file) != 1 && length(metric_file) != num_procs) {
stop(length(metric_file), " metric(s) provided. Must provide ",
Expand Down
8 changes: 4 additions & 4 deletions R/csv.R
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ read_cmdstan_csv <- function(files,
sampler_diagnostics = NULL,
format = getOption("cmdstanr_draws_format", NULL)) {
format <- assert_valid_draws_format(format)
checkmate::assert_file_exists(files, access = "r", extension = "csv")
assert_file_exists(files, access = "r", extension = "csv")
metadata <- NULL
warmup_draws <- list()
draws <- list()
Expand Down Expand Up @@ -237,7 +237,7 @@ read_cmdstan_csv <- function(files,
fread_cmd <- paste0(
grep_path_quotes,
" -v \"^#\" --color=never \"",
output_file,
wsl_safe_path(output_file, revert = TRUE),
"\""
)
} else {
Expand Down Expand Up @@ -556,7 +556,7 @@ for (method in unavailable_methods_CmdStanFit_CSV) {
#' mass matrix (or its diagonal depending on the metric).
#'
read_csv_metadata <- function(csv_file) {
checkmate::assert_file_exists(csv_file, access = "r", extension = "csv")
assert_file_exists(csv_file, access = "r", extension = "csv")
inv_metric_next <- FALSE
csv_file_info <- list()
csv_file_info$inv_metric <- NULL
Expand All @@ -579,7 +579,7 @@ read_csv_metadata <- function(csv_file) {
fread_cmd <- paste0(
grep_path_quotes,
" \"^[#a-zA-Z]\" --color=never \"",
csv_file,
wsl_safe_path(csv_file, revert = TRUE),
"\""
)
} else {
Expand Down
9 changes: 9 additions & 0 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ CmdStanFit <- R6::R6Class(
if (!is.null(private$model_methods_env_$model_ptr)) {
initialize_model_pointer(private$model_methods_env_, self$data_file(), 0)
}
# Need to update the output directory path to one that can be accessed
# from Windows, for the post-processing of results
self$runset$args$output_dir <- wsl_safe_path(self$runset$args$output_dir,
revert = TRUE)
invisible(self)
},
num_procs = function() {
Expand Down Expand Up @@ -303,6 +307,11 @@ CmdStanFit$set("public", name = "init", value = init)
#' }
#'
init_model_methods <- function(seed = 0, verbose = FALSE, hessian = FALSE) {
if (os_is_wsl()) {
stop("Additional model methods are not currently available with ",
"WSL CmdStan and will not be compiled",
call. = FALSE)
}
require_suggested_package("Rcpp")
require_suggested_package("RcppEigen")
if (length(private$model_methods_env_$hpp_code_) == 0) {
Expand Down
55 changes: 35 additions & 20 deletions R/install.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,10 @@ install_cmdstan <- function(dir = NULL,
call. = FALSE)
wsl <- FALSE
} else {
Sys.setenv("CMDSTANR_USE_WSL" = 1)
.cmdstanr$WSL <- TRUE
}
} else {
.cmdstanr$WSL <- FALSE
}
if (check_toolchain) {
check_cmdstan_toolchain(fix = FALSE, quiet = quiet)
Expand All @@ -108,13 +110,13 @@ install_cmdstan <- function(dir = NULL,
}
}
if (is.null(dir)) {
dir <- cmdstan_default_install_path()
dir <- cmdstan_default_install_path(wsl = wsl)
if (!dir.exists(dir)) {
dir.create(dir, recursive = TRUE)
}
} else {
dir <- repair_path(dir)
checkmate::assert_directory_exists(dir, access = "rwx")
assert_dir_exists(dir, access = "rwx")
}
if (!is.null(version)) {
if (!is.null(release_url)) {
Expand All @@ -125,7 +127,6 @@ install_cmdstan <- function(dir = NULL,
release_url <- paste0("https://github.com/stan-dev/cmdstan/releases/download/v",
version, "/cmdstan-", version, cmdstan_arch_suffix(version), ".tar.gz")
}
wsl_prefix <- ifelse(isTRUE(wsl), "wsl-", "")
if (!is.null(release_url)) {
if (!endsWith(release_url, ".tar.gz")) {
stop(release_url, " is not a .tar.gz archive!",
Expand All @@ -137,14 +138,14 @@ install_cmdstan <- function(dir = NULL,
tar_name <- utils::tail(split_url[[1]], n = 1)
cmdstan_ver <- substr(tar_name, 0, nchar(tar_name) - 7)
tar_gz_file <- paste0(cmdstan_ver, ".tar.gz")
dir_cmdstan <- file.path(dir, paste0(wsl_prefix, cmdstan_ver))
dir_cmdstan <- file.path(dir, cmdstan_ver)
dest_file <- file.path(dir, tar_gz_file)
} else {
ver <- latest_released_version()
message("* Latest CmdStan release is v", ver)
cmdstan_ver <- paste0("cmdstan-", ver, cmdstan_arch_suffix(ver))
tar_gz_file <- paste0(cmdstan_ver, ".tar.gz")
dir_cmdstan <- file.path(dir, paste0(wsl_prefix, cmdstan_ver))
dir_cmdstan <- file.path(dir, cmdstan_ver)
message("* Installing CmdStan v", ver, " in ", dir_cmdstan)
message("* Downloading ", tar_gz_file, " from GitHub...")
download_url <- github_download_url(ver)
Expand All @@ -164,17 +165,34 @@ install_cmdstan <- function(dir = NULL,
stop("Download of CmdStan failed. Please try again.", call. = FALSE)
}
message("* Download complete")

message("* Unpacking archive...")
untar_rc <- utils::untar(
dest_file,
exdir = dir_cmdstan,
extras = "--strip-components 1"
)
if (untar_rc != 0) {
stop("Problem extracting tarball. Exited with return code: ", untar_rc, call. = FALSE)
if (wsl) {
# Significantly faster to use WSL to untar the downloaded archive, as there are
# similar IO issues accessing the WSL filesystem from windows
wsl_tar_gz_file <- gsub(paste0("//wsl$/", wsl_distro_name()), "",
dest_file, fixed = TRUE)
wsl_tar_gz_file <- wsl_safe_path(wsl_tar_gz_file)
untar_rc <- processx::run(
command = "wsl",
args = c("tar", "-xf", wsl_tar_gz_file, "-C",
gsub(tar_gz_file, "", wsl_tar_gz_file))
)
remove_rc <- processx::run(
command = "wsl",
args = c("rm", wsl_tar_gz_file)
)
} else {
untar_rc <- utils::untar(
dest_file,
exdir = dir_cmdstan,
extras = "--strip-components 1"
)
if (untar_rc != 0) {
stop("Problem extracting tarball. Exited with return code: ", untar_rc, call. = FALSE)
}
file.remove(dest_file)
}
file.remove(dest_file)

cmdstan_make_local(dir = dir_cmdstan, cpp_options = cpp_options, append = TRUE)
# Setting up native M1 compilation of CmdStan and its downstream libraries
if (is_rosetta2()) {
Expand All @@ -186,7 +204,7 @@ install_cmdstan <- function(dir = NULL,
append = TRUE
)
}
if (is_rtools42_toolchain() && !os_is_wsl()) {
if (is_rtools42_toolchain() && !wsl) {
cmdstan_make_local(
dir = dir_cmdstan,
cpp_options = list(
Expand Down Expand Up @@ -521,10 +539,7 @@ install_toolchain <- function(quiet = FALSE) {
}

check_wsl_toolchain <- function() {
wsl_inaccessible <- processx::run(command = "wsl",
args = "uname",
error_on_status = FALSE)
if (wsl_inaccessible$status) {
if (!wsl_installed()) {
stop("\n", "A WSL distribution is not installed or is not accessible.",
"\n", "Please see the Microsoft documentation for guidance on installing WSL: ",
"\n", "https://docs.microsoft.com/en-us/windows/wsl/install",
Expand Down
25 changes: 20 additions & 5 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ CmdStanModel <- R6::R6Class(
self$functions <- new.env()
self$functions$compiled <- FALSE
if (!is.null(stan_file)) {
checkmate::assert_file_exists(stan_file, access = "r", extension = "stan")
assert_file_exists(stan_file, access = "r", extension = "stan")
checkmate::assert_flag(compile)
private$stan_file_ <- absolute_path(stan_file)
private$stan_code_ <- readLines(stan_file)
Expand All @@ -250,7 +250,7 @@ CmdStanModel <- R6::R6Class(
ext <- if (os_is_windows() && !os_is_wsl()) "exe" else ""
private$exe_file_ <- repair_path(absolute_path(exe_file))
if (is.null(stan_file)) {
checkmate::assert_file_exists(private$exe_file_, access = "r", extension = ext)
assert_file_exists(private$exe_file_, access = "r", extension = ext)
private$model_name_ <- sub(" ", "_", strip_ext(basename(private$exe_file_)))
}
}
Expand Down Expand Up @@ -317,7 +317,7 @@ CmdStanModel <- R6::R6Class(
if (is.null(dir)) {
dir <- dirname(private$stan_file_)
}
checkmate::assert_directory_exists(dir, access = "r")
assert_dir_exists(dir, access = "r")
new_hpp_loc <- file.path(dir, paste0(strip_ext(basename(private$stan_file_)), ".hpp"))
file.copy(self$hpp_file(), new_hpp_loc, overwrite = TRUE)
file.remove(self$hpp_file())
Expand Down Expand Up @@ -471,7 +471,7 @@ compile <- function(quiet = TRUE,
}
if (!is.null(dir)) {
dir <- repair_path(dir)
checkmate::assert_directory_exists(dir, access = "rw")
assert_dir_exists(dir, access = "rw")
if (length(self$exe_file()) != 0) {
private$exe_file_ <- file.path(dir, basename(self$exe_file()))
}
Expand Down Expand Up @@ -524,6 +524,15 @@ compile <- function(quiet = TRUE,
}
}

if (os_is_wsl() && (compile_model_methods || compile_standalone)) {
warning("Additional model methods and standalone functions are not ",
"currently available with WSL CmdStan and will not be compiled",
call. = FALSE)
compile_model_methods <- FALSE
compile_standalone <- FALSE
compile_hessian_method <- FALSE
}

temp_stan_file <- tempfile(pattern = "model-", fileext = ".stan")
file.copy(self$stan_file(), temp_stan_file, overwrite = TRUE)
temp_file_no_ext <- strip_ext(temp_stan_file)
Expand Down Expand Up @@ -629,6 +638,12 @@ compile <- function(quiet = TRUE,
file.remove(exe)
}
file.copy(tmp_exe, exe, overwrite = TRUE)
if (os_is_wsl()) {
res <- processx::run(
command = "wsl",
args = c("chmod", "+x", wsl_safe_path(exe))
)
}
private$exe_file_ <- exe
private$cpp_options_ <- cpp_options
private$precompile_cpp_options_ <- NULL
Expand Down Expand Up @@ -1806,7 +1821,7 @@ cpp_options_to_compile_flags <- function(cpp_options) {
include_paths_stanc3_args <- function(include_paths = NULL) {
stancflags <- NULL
if (!is.null(include_paths)) {
checkmate::assert_directory_exists(include_paths, access = "r")
assert_dir_exists(include_paths, access = "r")
include_paths <- sapply(absolute_path(include_paths), wsl_safe_path)
paths_w_space <- grep(" ", include_paths)
include_paths[paths_w_space] <- paste0("'", include_paths[paths_w_space], "'")
Expand Down
Loading

0 comments on commit f8f6321

Please sign in to comment.