Skip to content

Commit

Permalink
[R-package] Promote number of threads to top-level argument in `light…
Browse files Browse the repository at this point in the history
…gbm()` and change default to number of cores (#4972)
  • Loading branch information
david-cortes authored Apr 1, 2022
1 parent 4ae3d13 commit 33eb037
Show file tree
Hide file tree
Showing 14 changed files with 126 additions and 11 deletions.
6 changes: 3 additions & 3 deletions .ci/test_r_package.sh
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,13 @@ if [[ $OS_NAME == "macos" ]]; then
fi
fi

# Manually install Depends and Imports libraries + 'knitr', 'rmarkdown', 'testthat'
# Manually install Depends and Imports libraries + 'knitr', 'RhpcBLASctl', 'rmarkdown', 'testthat'
# to avoid a CI-time dependency on devtools (for devtools::install_deps())
# NOTE: testthat is not required when running rchk
if [[ "${TASK}" == "r-rchk" ]]; then
packages="c('data.table', 'jsonlite', 'knitr', 'Matrix', 'R6', 'rmarkdown')"
packages="c('data.table', 'jsonlite', 'knitr', 'Matrix', 'R6', 'RhpcBLASctl', 'rmarkdown')"
else
packages="c('data.table', 'jsonlite', 'knitr', 'Matrix', 'R6', 'rmarkdown', 'testthat')"
packages="c('data.table', 'jsonlite', 'knitr', 'Matrix', 'R6', 'RhpcBLASctl', 'rmarkdown', 'testthat')"
fi
compile_from_source="both"
if [[ $OS_NAME == "macos" ]]; then
Expand Down
2 changes: 1 addition & 1 deletion .ci/test_r_package_solaris.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ apt-get install --no-install-recommends -y \

# installation of dependencies needs to happen before building the package,
# since `R CMD build` needs to install the package to build vignettes
Rscript -e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'Matrix', 'rmarkdown', 'rhub', 'testthat'), dependencies = c('Depends', 'Imports', 'LinkingTo'), repos = 'https://cran.r-project.org', Ncpus = parallel::detectCores())" || exit -1
Rscript -e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'Matrix', 'RhpcBLASctl', 'rmarkdown', 'rhub', 'testthat'), dependencies = c('Depends', 'Imports', 'LinkingTo'), repos = 'https://cran.r-project.org', Ncpus = parallel::detectCores())" || exit -1

sh build-cran-package.sh || exit -1

Expand Down
2 changes: 1 addition & 1 deletion .ci/test_r_package_valgrind.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash

RDscriptvalgrind -e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'Matrix', 'rmarkdown', 'testthat'), repos = 'https://cran.r-project.org', Ncpus = parallel::detectCores())" || exit -1
RDscriptvalgrind -e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'Matrix', 'RhpcBLASctl', 'rmarkdown', 'testthat'), repos = 'https://cran.r-project.org', Ncpus = parallel::detectCores())" || exit -1
sh build-cran-package.sh \
--r-executable=RDvalgrind \
|| exit -1
Expand Down
2 changes: 1 addition & 1 deletion .ci/test_r_package_windows.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ Start-Process -FilePath Rtools.exe -NoNewWindow -Wait -ArgumentList "/VERYSILENT
Write-Output "Done installing Rtools"

Write-Output "Installing dependencies"
$packages = "c('data.table', 'jsonlite', 'knitr', 'Matrix', 'processx', 'R6', 'rmarkdown', 'testthat'), dependencies = c('Imports', 'Depends', 'LinkingTo')"
$packages = "c('data.table', 'jsonlite', 'knitr', 'Matrix', 'processx', 'R6', 'RhpcBLASctl', 'rmarkdown', 'testthat'), dependencies = c('Imports', 'Depends', 'LinkingTo')"
Run-R-Code-Redirect-Stderr "options(install.packages.check.source = 'no'); install.packages($packages, repos = '$env:CRAN_MIRROR', type = 'binary', lib = '$env:R_LIB_PATH', Ncpus = parallel::detectCores())" ; Check-Output $?

# MiKTeX and pandoc can be skipped on non-MinGW builds, since we don't
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/r_package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ jobs:
- name: Install packages
shell: bash
run: |
RDscript${{ matrix.r_customization }} -e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'Matrix', 'rmarkdown', 'testthat'), repos = 'https://cran.r-project.org', Ncpus = parallel::detectCores())"
RDscript${{ matrix.r_customization }} -e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'Matrix', 'RhpcBLASctl', 'rmarkdown', 'testthat'), repos = 'https://cran.r-project.org', Ncpus = parallel::detectCores())"
sh build-cran-package.sh --r-executable=RD${{ matrix.r_customization }}
RD${{ matrix.r_customization }} CMD INSTALL lightgbm_*.tar.gz || exit -1
- name: Run tests with sanitizers
Expand Down Expand Up @@ -219,7 +219,7 @@ jobs:
shell: bash
run: |
export PATH=/opt/R-devel/bin/:${PATH}
Rscript -e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'Matrix', 'rmarkdown', 'testthat'), repos = 'https://cran.r-project.org', Ncpus = parallel::detectCores())"
Rscript -e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'Matrix', 'RhpcBLASctl', 'rmarkdown', 'testthat'), repos = 'https://cran.r-project.org', Ncpus = parallel::detectCores())"
sh build-cran-package.sh
R CMD check --as-cran --run-donttest lightgbm_*.tar.gz || exit -1
if grep -q -E "NOTE|WARNING|ERROR" lightgbm.Rcheck/00check.log; then
Expand Down
2 changes: 1 addition & 1 deletion .vsts-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ jobs:
R_LIB_PATH=~/Rlib
export R_LIBS=${R_LIB_PATH}
mkdir -p ${R_LIB_PATH}
RDscript -e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'Matrix', 'rmarkdown'), lib = '${R_LIB_PATH}', dependencies = c('Depends', 'Imports', 'LinkingTo'), repos = 'https://cran.r-project.org', Ncpus = parallel::detectCores())" || exit -1
RDscript -e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'Matrix', 'RhpcBLASctl', 'rmarkdown'), lib = '${R_LIB_PATH}', dependencies = c('Depends', 'Imports', 'LinkingTo'), repos = 'https://cran.r-project.org', Ncpus = parallel::detectCores())" || exit -1
sh build-cran-package.sh --r-executable=RD || exit -1
mv lightgbm_${LGB_VER}.tar.gz $(Build.ArtifactStagingDirectory)/lightgbm-${LGB_VER}-r-cran.tar.gz
displayName: 'Build CRAN R-package'
Expand Down
2 changes: 2 additions & 0 deletions R-package/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ VignetteBuilder: knitr
Suggests:
knitr,
processx,
RhpcBLASctl,
rmarkdown,
testthat
Depends:
Expand All @@ -61,6 +62,7 @@ Imports:
jsonlite (>= 1.0),
Matrix (>= 1.1-0),
methods,
parallel,
utils
SystemRequirements:
C++11
Expand Down
1 change: 1 addition & 0 deletions R-package/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ importFrom(graphics,barplot)
importFrom(graphics,par)
importFrom(jsonlite,fromJSON)
importFrom(methods,is)
importFrom(parallel,detectCores)
importFrom(stats,quantile)
importFrom(utils,modifyList)
importFrom(utils,read.delim)
Expand Down
2 changes: 1 addition & 1 deletion R-package/R/lgb.restore_handle.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#' model <- lightgbm(
#' agaricus.train$data
#' , agaricus.train$label
#' , params = list(objective = "binary", nthreads = 1L)
#' , params = list(objective = "binary")
#' , nrounds = 5L
#' , verbose = 0)
#' fname <- tempfile(fileext="rds")
Expand Down
26 changes: 26 additions & 0 deletions R-package/R/lightgbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,22 @@ NULL
#' \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#objective}{
#' the "objective" item of the "Parameters" section of the documentation}.
#' @param init_score initial score is the base prediction lightgbm will boost from
#' @param num_threads Number of parallel threads to use. For best speed, this should be set to the number of
#' physical cores in the CPU - in a typical x86-64 machine, this corresponds to half the
#' number of maximum threads.
#'
#' Be aware that using too many threads can result in speed degradation in smaller datasets
#' (see the parameters documentation for more details).
#'
#' If passing zero, will use the default number of threads configured for OpenMP
#' (typically controlled through an environment variable \code{OMP_NUM_THREADS}).
#'
#' If passing \code{NULL} (the default), will try to use the number of physical cores in the
#' system, but be aware that getting the number of cores detected correctly requires package
#' \code{RhpcBLASctl} to be installed.
#'
#' This parameter gets overriden by \code{num_threads} and its aliases under \code{params}
#' if passed there.
#' @param ... Additional arguments passed to \code{\link{lgb.train}}. For example
#' \itemize{
#' \item{\code{valids}: a list of \code{lgb.Dataset} objects, used for validation}
Expand Down Expand Up @@ -129,13 +145,23 @@ lightgbm <- function(data,
serializable = TRUE,
objective = "regression",
init_score = NULL,
num_threads = NULL,
...) {

# validate inputs early to avoid unnecessary computation
if (nrounds <= 0L) {
stop("nrounds should be greater than zero")
}

if (is.null(num_threads)) {
num_threads <- lgb.get.default.num.threads()
}
params <- lgb.check.wrapper_param(
main_param_name = "num_threads"
, params = params
, alternative_kwarg_value = num_threads
)

# Set data to a temporary variable
dtrain <- data

Expand Down
23 changes: 23 additions & 0 deletions R-package/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,26 @@ lgb.check.wrapper_param <- function(main_param_name, params, alternative_kwarg_v
params[[main_param_name]] <- alternative_kwarg_value
return(params)
}

#' @importFrom parallel detectCores
lgb.get.default.num.threads <- function() {
if (requireNamespace("RhpcBLASctl", quietly = TRUE)) { # nolint
return(RhpcBLASctl::get_num_cores())
} else {
msg <- "Optional package 'RhpcBLASctl' not found."
cores <- 0L
if (Sys.info()["sysname"] != "Linux") {
cores <- parallel::detectCores(logical = FALSE)
if (is.na(cores) || cores < 0L) {
cores <- 0L
}
}
if (cores == 0L) {
msg <- paste(msg, "Will use default number of OpenMP threads.", sep = " ")
} else {
msg <- paste(msg, "Detection of CPU cores might not be accurate.", sep = " ")
}
warning(msg)
return(cores)
}
}
2 changes: 1 addition & 1 deletion R-package/man/lgb.restore_handle.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions R-package/man/lightgbm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 45 additions & 0 deletions R-package/tests/testthat/test_basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -2928,6 +2928,51 @@ test_that("lightgbm() defaults to 'regression' objective if objective not otherw
expect_false(any(model_txt_lines == "objective=regression_l1"))
})

test_that("lightgbm() accepts 'num_threads' as either top-level argument or under params", {
bst <- lightgbm(
data = train$data
, label = train$label
, nrounds = 5L
, verbose = VERBOSITY
, num_threads = 1L
)
expect_equal(bst$params$num_threads, 1L)
model_txt_lines <- strsplit(
x = bst$save_model_to_string()
, split = "\n"
)[[1L]]
expect_true(any(grepl("\\[num_threads: 1\\]", model_txt_lines)))

bst <- lightgbm(
data = train$data
, label = train$label
, nrounds = 5L
, verbose = VERBOSITY
, params = list(num_threads = 1L)
)
expect_equal(bst$params$num_threads, 1L)
model_txt_lines <- strsplit(
x = bst$save_model_to_string()
, split = "\n"
)[[1L]]
expect_true(any(grepl("\\[num_threads: 1\\]", model_txt_lines)))

bst <- lightgbm(
data = train$data
, label = train$label
, nrounds = 5L
, verbose = VERBOSITY
, num_threads = 10L
, params = list(num_threads = 1L)
)
expect_equal(bst$params$num_threads, 1L)
model_txt_lines <- strsplit(
x = bst$save_model_to_string()
, split = "\n"
)[[1L]]
expect_true(any(grepl("\\[num_threads: 1\\]", model_txt_lines)))
})

test_that("lightgbm() accepts 'weight' and 'weights'", {
data(mtcars)
X <- as.matrix(mtcars[, -1L])
Expand Down

0 comments on commit 33eb037

Please sign in to comment.