Skip to content

Commit

Permalink
Merge pull request #47 from rstudio/updates
Browse files Browse the repository at this point in the history
Prepares release
  • Loading branch information
edgararuiz authored Apr 30, 2024
2 parents 4fc50e2 + 96aa880 commit c3291dd
Show file tree
Hide file tree
Showing 27 changed files with 418 additions and 325 deletions.
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
^derby\.log$
^\.github$
^codecov\.yml$
^cran-comments\.md$
11 changes: 9 additions & 2 deletions .github/workflows/Coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ on:
pull_request:
branches: main

name: Tests
name: Coverage

jobs:
Coverage:
runs-on: ubuntu-latest
env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
R_KEEP_PKG_SOURCE: yes

steps:
- uses: actions/checkout@v3

Expand Down Expand Up @@ -46,6 +46,12 @@ jobs:
path: /home/runner/spark/spark-3.5.1-bin-hadoop3
key: sparklyr-spark-3.5.1-bin-hadoop3-2

- name: Install Spark (via sparklyr)
if: steps.cache-spark-2.outputs.cache-hit != 'true'
run: |
sparklyr::spark_install(version = "3.5")
shell: Rscript {0}

- name: Cache Scala
id: cache-scala
uses: actions/cache@v3
Expand Down Expand Up @@ -76,6 +82,7 @@ jobs:

- name: Test coverage
run: |
Sys.setenv("SPARK_VERSION" = "3.5")
Sys.setenv("CODE_COVERAGE" = "true")
devtools::load_all()
covr::codecov(
Expand Down
7 changes: 7 additions & 0 deletions .github/workflows/Tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ jobs:
path: /home/runner/spark/spark-3.4.2-bin-hadoop3
key: sparklyr-spark-3.4.2-bin-hadoop3-2

- name: Install Spark (via sparklyr)
if: steps.cache-spark-2.outputs.cache-hit != 'true'
run: |
sparklyr::spark_install(version = "3.4")
shell: Rscript {0}

- name: Cache Scala
id: cache-scala
uses: actions/cache@v3
Expand Down Expand Up @@ -75,6 +81,7 @@ jobs:

- name: R Tests
run: |
Sys.setenv("SPARK_VERSION" = "3.4")
devtools::load_all()
devtools::test()
shell: Rscript {0}
Expand Down
9 changes: 5 additions & 4 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: sparkxgb
Type: Package
Title: Interface for 'XGBoost' on 'Apache Spark'
Version: 0.1.2.9000
Version: 0.2
Authors@R: c(person("Kevin", "Kuo", email = "kevin.kuo@rstudio.com",
role = "aut", comment = c(ORCID = "0000-0001-7803-7901")),
person("Yitao", "Li", email = "yitaoli1990@gmail.com",
Expand All @@ -11,8 +11,8 @@ Authors@R: c(person("Kevin", "Kuo", email = "kevin.kuo@rstudio.com",
)
Maintainer: Edgar Ruiz <edgar@posit.co>
Description: A 'sparklyr' <https://spark.posit.co/> extension that provides an R
interface for 'XGBoost' <https://github.com/dmlc/xgboost> on 'Apache Spark'. 'XGBoost' is an
optimized distributed gradient boosting library.
interface for 'XGBoost' <https://github.com/dmlc/xgboost> on 'Apache Spark'.
'XGBoost' is an optimized distributed gradient boosting library.
License: Apache License (>= 2.0)
Encoding: UTF-8
LazyData: true
Expand All @@ -22,7 +22,8 @@ Imports:
sparklyr,
rlang,
magrittr,
vctrs
vctrs,
fs
RoxygenNote: 7.3.1
Suggests:
dplyr,
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ S3method(xgboost_regressor,tbl_spark)
export("%>%")
export(xgboost_classifier)
export(xgboost_regressor)
import(fs)
importFrom(magrittr,"%>%")
importFrom(rlang,`%||%`)
importFrom(sparklyr,invoke)
Expand Down
18 changes: 13 additions & 5 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# sparkxgb 0.1.2.9000
# sparkxgb 0.2
### Fixes

- Avoids sending two deprecated parameters to XGBoost. The default arguments in
the R function are NULL, and it will return an error message if the call intends
Expand All @@ -8,12 +9,19 @@ to use them:

- Timeout Request Workers - No longer supported since XGBoost version 1.7

- Adds setMissing param handler for XGBoostRegressor in the Scala code. The
`missing` parameter in `xgboost_regressor()` was not working.

- Creates the JAR for Scala version 2.12 only. The code is simple enough that
it does not seem to need multiple Spark version compiling. This also means that
Scala 2.11 is not supported at this time.

### Internal improvements

- Modernizes the entire `testthat` suite and expands it to provide more
coverage


- Modernizes and expands CI testing. The single CI job is now expanded to three:

- R package check (without running the tests) against the three major OS's
- `testthat` tests against Spark version 3.5
- Coverage testing, also against Spark version 3.5
Expand All @@ -22,8 +30,8 @@ coverage

- Improves download, preparation and building of the JAR

- Updates and cleans up the call that sets the Maven package to be used in the
Spark session
- Improves and cleans up the code that selects the JAR and Maven package to use.
It now depends on a text file created by the script that updates the JARs.

- Updates Roxygen and `testthat` versions

Expand Down
10 changes: 5 additions & 5 deletions R/compact-forge.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ cast_string_list <- function(x, allow_null = FALSE, ...) {
cast_list(x, character(), allow_null = allow_null)
}

cast_choice <- function(x, choices, error_arg = rlang::caller_arg(x),
cast_choice <- function(x, choices, error_arg = rlang::caller_arg(x),
error_call = rlang::caller_env(), ...) {
rlang::arg_match(x, choices, error_arg = error_arg, error_call = error_call)
}
Expand All @@ -68,19 +68,19 @@ gte <- function(l) {

# Builds a predicate that checks whether every element of its argument lies
# within the requested bounds.
#
# Arguments:
#   l           Lower bound, or NULL for no lower bound.
#   u           Upper bound, or NULL for no upper bound.
#   incl_lower  If TRUE the lower bound is inclusive (uses `gte()`),
#               otherwise exclusive (uses `gt()`).
#   incl_upper  If TRUE the upper bound is inclusive (uses `lte()`),
#               otherwise exclusive (uses `lt()`).
#
# Returns a function of one argument `x` that yields a single TRUE/FALSE.
# Raises an error when both `l` and `u` are NULL.
bounded <- function(l = NULL, u = NULL, incl_lower = TRUE, incl_upper = TRUE) {
  if (is.null(l) && is.null(u)) {
    stop("At least one of `l` or `u` must be specified.", call. = FALSE)
  }

  # A missing bound is always satisfied. The placeholder has to accept `x`
  # because the returned predicate calls both bounds with one argument; a
  # zero-argument function here fails with "unused argument".
  always_true <- function(x) TRUE

  lower_bound <- if (is.null(l)) {
    always_true
  } else if (incl_lower) {
    gte(l)
  } else {
    gt(l)
  }

  upper_bound <- if (is.null(u)) {
    always_true
  } else if (incl_upper) {
    lte(u)
  } else {
    lt(u)
  }

  function(x) lower_bound(x) && upper_bound(x)
}

Expand All @@ -92,4 +92,4 @@ lt <- function(u) {
lte <- function(u) {
force(u)
function(x) all(x <= u)
}
}
28 changes: 20 additions & 8 deletions R/dependencies.R
Original file line number Diff line number Diff line change
@@ -1,16 +1,28 @@
spark_dependencies <- function(spark_version, scala_version, ...) {
if (scala_version != "2.12") {
stop(sprintf("Unsupported Scala version '%s'.", scala_version))
}
sparklyr::spark_dependency(
jars = system.file("java/sparkxgb-3.0-2.12.jar", package = "sparkxgb"),
packages = (
if (scala_version == "2.12") {
"ml.dmlc:xgboost4j-spark_2.12:2.0.3"
} else {
stop(sprintf("Unsupported Scala version '%s'.", scala_version))
}
)
jars = package_file("java/sparkxgb-3.0-2.12.jar"),
packages = readLines(package_file("maven/scala_212.txt"))
)
}

# Package load hook, run by R when the package namespace is loaded.
# It passes the package name to sparklyr::register_extension(); presumably this
# is what lets sparklyr pick up spark_dependencies() from this package
# (confirm against the sparklyr documentation).
# `libname` is part of the standard .onLoad signature and is not used here.
.onLoad <- function(libname, pkgname) {
sparklyr::register_extension(pkgname)
}

# Resolves the location of a file shipped with the package.
#
# Arguments:
#   ...  Path components, joined with fs::path(), relative to the package's
#        installed root (for example "java/sparkxgb-3.0-2.12.jar").
#
# The path is first looked up under "inst/" relative to the current working
# directory; if it is not there, the installed copy of 'sparkxgb' is queried
# through system.file().
#
# Returns the resolved path. Raises an error when neither location has the file.
package_file <- function(...) {
  rel_path <- path(...)
  source_path <- path("inst", rel_path)

  resolved <- if (file_exists(source_path)) {
    source_path
  } else {
    # system.file() returns "" when the file is missing, which the existence
    # check below reports as not found.
    system.file(rel_path, package = "sparkxgb")
  }

  if (!file_exists(resolved)) {
    stop(paste0("'", rel_path, "' not found"))
  }
  resolved
}
3 changes: 2 additions & 1 deletion R/imports.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#' @importFrom rlang `%||%`
#' @importFrom sparklyr invoke random_string spark_connection jobj_set_param
NULL
#' @import fs
NULL
20 changes: 11 additions & 9 deletions R/xgboost_classifier.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ xgboost_classifier.spark_connection <- function(
)

args <- validator_xgboost_classifier(args)

xg_unsupported(args)

stage_class <- "ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier"
Expand Down Expand Up @@ -164,7 +164,7 @@ xgboost_classifier.spark_connection <- function(

if (!is.nan(args[["missing"]])) {
jobj <- sparklyr::invoke_static(
x, "sparkxgb.Utils", "setMissingParam", jobj, args[["missing"]]
x, "sparkxgb.Utils", "setMissingParamClass", jobj, args[["missing"]]
)
}

Expand Down Expand Up @@ -411,14 +411,16 @@ ml_feature_importances.ml_model_xgboost_classification <- function(model, ...) {
}

xg_unsupported <- function(args) {
if(!is.null(args$sketch_eps)) {
stop("As of XGBoost version 1.6.0, 'Sketch EPS'",
" is no longer supported, consider using 'Max Bins'"
)
if (!is.null(args$sketch_eps)) {
stop(
"As of XGBoost version 1.6.0, 'Sketch EPS'",
" is no longer supported, consider using 'Max Bins'"
)
}
if(!is.null(args$timeout_request_workers)) {
stop("As of XGBoost version 1.7.0, 'Timeout Request Workers'",
" is no longer supported"
if (!is.null(args$timeout_request_workers)) {
stop(
"As of XGBoost version 1.7.0, 'Timeout Request Workers'",
" is no longer supported"
)
}
invisible()
Expand Down
Loading

0 comments on commit c3291dd

Please sign in to comment.