diff --git a/.github/dependabot.yml b/.github/dependabot.yml index c03a52c605c9..06badec5f2e2 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -8,7 +8,7 @@ updates: - package-ecosystem: "maven" directory: "/jvm-packages" schedule: - interval: "daily" + interval: "monthly" - package-ecosystem: "maven" directory: "/jvm-packages/xgboost4j" schedule: @@ -16,11 +16,11 @@ updates: - package-ecosystem: "maven" directory: "/jvm-packages/xgboost4j-gpu" schedule: - interval: "daily" + interval: "monthly" - package-ecosystem: "maven" directory: "/jvm-packages/xgboost4j-example" schedule: - interval: "daily" + interval: "monthly" - package-ecosystem: "maven" directory: "/jvm-packages/xgboost4j-spark" schedule: @@ -28,4 +28,8 @@ updates: - package-ecosystem: "maven" directory: "/jvm-packages/xgboost4j-spark-gpu" schedule: - interval: "daily" + interval: "monthly" + - package-ecosystem: "github-actions" + directory: / + schedule: + interval: "monthly" diff --git a/.github/workflows/i386.yml b/.github/workflows/i386.yml index 4a4d65b25b61..1c4e98010310 100644 --- a/.github/workflows/i386.yml +++ b/.github/workflows/i386.yml @@ -5,6 +5,10 @@ on: [push, pull_request] permissions: contents: read # to fetch code (actions/checkout) +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: build-32bit: name: Build 32-bit @@ -15,7 +19,7 @@ jobs: ports: - 5000:5000 steps: - - uses: actions/checkout@v2.5.0 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: 'true' - name: Set up Docker Buildx diff --git a/.github/workflows/jvm_tests.yml b/.github/workflows/jvm_tests.yml index bbded088387f..9ef314ca5b0b 100644 --- a/.github/workflows/jvm_tests.yml +++ b/.github/workflows/jvm_tests.yml @@ -5,6 +5,10 @@ on: [push, pull_request] permissions: contents: read # to fetch code (actions/checkout) +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: test-with-jvm: name: Test JVM on OS ${{ matrix.os }} diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 133e151e5e4f..4755f9aaaad8 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -9,6 +9,10 @@ on: [push, pull_request] permissions: contents: read # to fetch code (actions/checkout) +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + # A workflow run is made up of one or more jobs that can run sequentially or in parallel jobs: gtest-cpu: @@ -19,7 +23,7 @@ jobs: matrix: os: [macos-11] steps: - - uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: 'true' - name: Install system packages @@ -45,7 +49,7 @@ jobs: matrix: os: [ubuntu-latest] steps: - - uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: 'true' - name: Install system packages @@ -72,10 +76,10 @@ jobs: os: [ubuntu-latest] python-version: ["3.8"] steps: - - uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: 'true' - - uses: mamba-org/provision-with-micromamba@f347426e5745fe3dfc13ec5baf20496990d0281f # v14 + - uses: mamba-org/provision-with-micromamba@3c96c0c27676490c63c18bc81f5c51895ac3e0e6 # v16 with: cache-downloads: true cache-env: true @@ -114,10 +118,10 @@ jobs: os: ["ubuntu-latest"] python-version: ["3.8"] steps: - - uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: 'true' - - uses: mamba-org/provision-with-micromamba@f347426e5745fe3dfc13ec5baf20496990d0281f # v14 + - uses: mamba-org/provision-with-micromamba@3c96c0c27676490c63c18bc81f5c51895ac3e0e6 # v16 with: cache-downloads: true cache-env: true @@ -171,7 +175,7 @@ jobs: runs-on: ubuntu-latest name: Code linting for C++ steps: - - uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: 'true' - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 diff --git a/.github/workflows/python_tests.yml b/.github/workflows/python_tests.yml index 3fbcc7a01acf..f0cad6382d87 100644 --- a/.github/workflows/python_tests.yml +++ b/.github/workflows/python_tests.yml @@ -9,6 +9,10 @@ defaults: run: shell: bash -l {0} +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: python-mypy-lint: runs-on: ubuntu-latest @@ -17,10 +21,10 @@ jobs: matrix: os: [ubuntu-latest] steps: - - uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: 'true' - - uses: mamba-org/provision-with-micromamba@f347426e5745fe3dfc13ec5baf20496990d0281f # v14 + - uses: mamba-org/provision-with-micromamba@3c96c0c27676490c63c18bc81f5c51895ac3e0e6 # v16 with: cache-downloads: true cache-env: true @@ -48,10 +52,10 @@ jobs: matrix: os: [ubuntu-latest] steps: - - uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: 'true' - - uses: mamba-org/provision-with-micromamba@f347426e5745fe3dfc13ec5baf20496990d0281f # v14 + - uses: mamba-org/provision-with-micromamba@3c96c0c27676490c63c18bc81f5c51895ac3e0e6 # v16 with: cache-downloads: true cache-env: true @@ -80,14 +84,14 @@ jobs: os: [macos-11, windows-latest] python-version: ["3.8"] steps: - - uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: 'true' - name: Install osx system dependencies if: matrix.os == 'macos-11' run: | brew install ninja libomp - - uses: conda-incubator/setup-miniconda@35d1405e78aa3f784fe3ce9a2eb378d5eeb62169 # v2.1.1 + - uses: conda-incubator/setup-miniconda@a4260408e20b96e80095f42ff7f1a15b27dd94ca # v3.0.4 with: auto-update-conda: true python-version: ${{ matrix.python-version }} @@ -118,11 +122,11 @@ jobs: - {os: macos-11} steps: - - uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: 'true' - - uses: mamba-org/provision-with-micromamba@f347426e5745fe3dfc13ec5baf20496990d0281f # v14 + - uses: mamba-org/provision-with-micromamba@3c96c0c27676490c63c18bc81f5c51895ac3e0e6 # v16 with: cache-downloads: true cache-env: true @@ -170,11 +174,11 @@ jobs: - {os: windows-latest, python-version: '3.8'} steps: - - uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: 'true' - - uses: conda-incubator/setup-miniconda@35d1405e78aa3f784fe3ce9a2eb378d5eeb62169 # v2.1.1 + - uses: conda-incubator/setup-miniconda@a4260408e20b96e80095f42ff7f1a15b27dd94ca # v3.0.4 with: auto-update-conda: true python-version: ${{ matrix.config.python-version }} @@ -214,11 +218,11 @@ jobs: - {os: ubuntu-latest, python-version: "3.8"} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: 'true' - - uses: mamba-org/provision-with-micromamba@f347426e5745fe3dfc13ec5baf20496990d0281f # v14 + - uses: mamba-org/provision-with-micromamba@3c96c0c27676490c63c18bc81f5c51895ac3e0e6 # v16 with: cache-downloads: true cache-env: true @@ -266,11 +270,11 @@ jobs: - {os: ubuntu-latest, python-version: "3.8"} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: 'true' - - uses: mamba-org/provision-with-micromamba@f347426e5745fe3dfc13ec5baf20496990d0281f # v14 + - uses: mamba-org/provision-with-micromamba@3c96c0c27676490c63c18bc81f5c51895ac3e0e6 # v16 with: cache-downloads: true cache-env: true @@ -305,7 +309,7 @@ jobs: os: [ubuntu-latest] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: 'true' diff --git a/.github/workflows/python_wheels.yml b/.github/workflows/python_wheels.yml index 129ab805f753..090b1f830213 100644 --- a/.github/workflows/python_wheels.yml +++ b/.github/workflows/python_wheels.yml @@ -25,7 +25,7 @@ jobs: - os: macos-14 platform_id: macosx_arm64 steps: - - uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: 'true' - uses: conda-incubator/setup-miniconda@v3.0.4 diff --git a/.github/workflows/r_nold.yml b/.github/workflows/r_nold.yml index a014c9138493..887470190035 100644 --- a/.github/workflows/r_nold.yml +++ b/.github/workflows/r_nold.yml @@ -10,6 +10,10 @@ on: permissions: contents: read # to fetch code (actions/checkout) +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: test-R-noLD: if: github.event.comment.body == '/gha run r-nold-test' && contains('OWNER,MEMBER,COLLABORATOR', github.event.comment.author_association) @@ -23,7 +27,7 @@ jobs: run: | apt update && apt install libcurl4-openssl-dev libssl-dev libssh2-1-dev libgit2-dev libglpk-dev libxml2-dev libharfbuzz-dev libfribidi-dev git -y - - uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: 'true' diff --git a/.github/workflows/r_tests.yml b/.github/workflows/r_tests.yml index 1ed5ca20b777..f3d83b823aff 100644 --- a/.github/workflows/r_tests.yml +++ b/.github/workflows/r_tests.yml @@ -8,6 +8,10 @@ env: permissions: contents: read # to fetch code (actions/checkout) +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: lintr: runs-on: ${{ matrix.config.os }} @@ -21,11 +25,11 @@ jobs: RSPM: ${{ matrix.config.rspm }} steps: - - uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: 'true' - - uses: r-lib/actions/setup-r@b7e68d63e51bdf225997973e2add36d551f60f02 # v2.8.7 + - uses: r-lib/actions/setup-r@929c772977a3a13c8733b363bf5a2f685c25dd91 # v2.9.0 with: r-version: ${{ matrix.config.r }} @@ -33,8 +37,8 @@ jobs: uses: actions/cache@937d24475381cd9c75ae6db12cb4e79714b926ed # v3.0.11 with: path: ${{ env.R_LIBS_USER }} - key: ${{ runner.os }}-r-${{ matrix.config.r }}-6-${{ hashFiles('R-package/DESCRIPTION') }} - restore-keys: ${{ runner.os }}-r-${{ matrix.config.r }}-6-${{ hashFiles('R-package/DESCRIPTION') }} + key: ${{ runner.os }}-r-${{ matrix.config.r }}-7-${{ hashFiles('R-package/DESCRIPTION') }} + restore-keys: ${{ runner.os }}-r-${{ matrix.config.r }}-7-${{ hashFiles('R-package/DESCRIPTION') }} - name: Install dependencies shell: Rscript {0} @@ -46,7 +50,7 @@ jobs: MAKEFLAGS="-j$(nproc)" R CMD INSTALL R-package/ Rscript tests/ci_build/lint_r.R $(pwd) - test-R-on-Windows: + test-Rpkg: runs-on: ${{ matrix.config.os }} name: Test R on OS ${{ matrix.config.os }}, R ${{ matrix.config.r }}, Compiler ${{ matrix.config.compiler }}, Build ${{ matrix.config.build }} strategy: @@ -54,16 +58,22 @@ jobs: matrix: config: - {os: windows-latest, r: 'release', compiler: 'mingw', build: 'autotools'} + - {os: ubuntu-latest, r: 'release', compiler: 'none', build: 'cmake'} env: R_REMOTES_NO_ERRORS_FROM_WARNINGS: true RSPM: ${{ matrix.config.rspm }} steps: - - uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 + - name: Install system dependencies + run: | + sudo apt update + sudo apt install libcurl4-openssl-dev libssl-dev libssh2-1-dev libgit2-dev libglpk-dev libxml2-dev libharfbuzz-dev libfribidi-dev + if: matrix.config.os == 'ubuntu-latest' + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: 'true' - - uses: r-lib/actions/setup-r@b7e68d63e51bdf225997973e2add36d551f60f02 # v2.8.7 + - uses: r-lib/actions/setup-r@929c772977a3a13c8733b363bf5a2f685c25dd91 # v2.9.0 with: r-version: ${{ matrix.config.r }} @@ -71,8 +81,8 @@ jobs: uses: actions/cache@937d24475381cd9c75ae6db12cb4e79714b926ed # v3.0.11 with: path: ${{ env.R_LIBS_USER }} - key: ${{ runner.os }}-r-${{ matrix.config.r }}-6-${{ hashFiles('R-package/DESCRIPTION') }} - restore-keys: ${{ runner.os }}-r-${{ matrix.config.r }}-6-${{ hashFiles('R-package/DESCRIPTION') }} + key: ${{ runner.os }}-r-${{ matrix.config.r }}-7-${{ hashFiles('R-package/DESCRIPTION') }} + restore-keys: ${{ runner.os }}-r-${{ matrix.config.r }}-7-${{ hashFiles('R-package/DESCRIPTION') }} - uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 with: @@ -89,12 +99,18 @@ jobs: - name: Test R run: | python tests/ci_build/test_r_package.py --compiler='${{ matrix.config.compiler }}' --build-tool="${{ matrix.config.build }}" --task=check + if: matrix.config.compiler != 'none' + + - name: Test R + run: | + python tests/ci_build/test_r_package.py --build-tool="${{ matrix.config.build }}" --task=check + if: matrix.config.compiler == 'none' test-R-on-Debian: name: Test R package on Debian runs-on: ubuntu-latest container: - image: rhub/debian-gcc-devel + image: rhub/debian-gcc-release steps: - name: Install system dependencies @@ -107,21 +123,21 @@ jobs: run: | git config --global --add safe.directory "${GITHUB_WORKSPACE}" - - uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: 'true' - name: Install dependencies shell: bash -l {0} run: | - /tmp/R-devel/bin/Rscript -e "source('./R-package/tests/helper_scripts/install_deps.R')" + Rscript -e "source('./R-package/tests/helper_scripts/install_deps.R')" - name: Test R shell: bash -l {0} run: | - python3 tests/ci_build/test_r_package.py --r=/tmp/R-devel/bin/R --build-tool=autotools --task=check + python3 tests/ci_build/test_r_package.py --r=/usr/bin/R --build-tool=autotools --task=check - - uses: dorny/paths-filter@v2 + - uses: dorny/paths-filter@v3 id: changes with: filters: | @@ -131,4 +147,4 @@ jobs: - name: Run document check if: steps.changes.outputs.r_package == 'true' run: | - python3 tests/ci_build/test_r_package.py --r=/tmp/R-devel/bin/R --task=doc + python3 tests/ci_build/test_r_package.py --r=/usr/bin/R --task=doc diff --git a/.github/workflows/scorecards.yml b/.github/workflows/scorecards.yml index 78cde0a43cb2..4651e2ac0dff 100644 --- a/.github/workflows/scorecards.yml +++ b/.github/workflows/scorecards.yml @@ -22,12 +22,12 @@ jobs: steps: - name: "Checkout code" - uses: actions/checkout@a12a3943b4bdde767164f792f33f40b04645d846 # tag=v3.0.0 + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: persist-credentials: false - name: "Run analysis" - uses: ossf/scorecard-action@08b4669551908b1024bb425080c797723083c031 # tag=v2.2.0 + uses: ossf/scorecard-action@dc50aa9510b46c811795eb24b2f1ba02a914e534 # v2.3.3 with: results_file: results.sarif results_format: sarif @@ -41,7 +41,7 @@ jobs: # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF # format to the repository Actions tab. - name: "Upload artifact" - uses: actions/upload-artifact@0b7f8abb1508181956e8e162db84b466c27e18ce # tag=v3.1.2 + uses: actions/upload-artifact@5d5d22a31266ced268874388b861e4b58bb5c2f3 # v4.3.1 with: name: SARIF file path: results.sarif @@ -49,6 +49,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard. - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@7b6664fa89524ee6e3c3e9749402d5afd69b3cd8 # tag=v2.14.1 + uses: github/codeql-action/upload-sarif@83a02f7883b12e0e4e1a146174f5e2292a01e601 # v2.16.4 with: sarif_file: results.sarif diff --git a/.github/workflows/update_rapids.yml b/.github/workflows/update_rapids.yml index 395a42148c23..9f9c85f62e28 100644 --- a/.github/workflows/update_rapids.yml +++ b/.github/workflows/update_rapids.yml @@ -3,7 +3,7 @@ name: update-rapids on: workflow_dispatch: schedule: - - cron: "0 20 * * *" # Run once daily + - cron: "0 20 * * 1" # Run once weekly permissions: pull-requests: write @@ -25,14 +25,14 @@ jobs: name: Check latest RAPIDS runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: 'true' - name: Check latest RAPIDS and update conftest.sh run: | bash tests/buildkite/update-rapids.sh - name: Create Pull Request - uses: peter-evans/create-pull-request@v5 + uses: peter-evans/create-pull-request@v6 if: github.ref == 'refs/heads/master' with: add-paths: | diff --git a/CMakeLists.txt b/CMakeLists.txt index dbfa1cdc225b..c69b0d2a3dc7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -69,7 +69,6 @@ option(USE_DMLC_GTEST "Use google tests bundled with dmlc-core submodule" OFF) option(USE_DEVICE_DEBUG "Generate CUDA device debug info." OFF) option(USE_NVTX "Build with cuda profiling annotations. Developers only." OFF) set(NVTX_HEADER_DIR "" CACHE PATH "Path to the stand-alone nvtx header") -option(RABIT_MOCK "Build rabit with mock" OFF) option(HIDE_CXX_SYMBOLS "Build shared library and hide all C++ symbols" OFF) option(KEEP_BUILD_ARTIFACTS_IN_BINARY_DIR "Output build artifacts in CMake binary dir" OFF) ## CUDA @@ -282,9 +281,6 @@ if(MSVC) endif() endif() -# rabit -add_subdirectory(rabit) - # core xgboost add_subdirectory(${xgboost_SOURCE_DIR}/src) target_link_libraries(objxgboost PUBLIC dmlc) diff --git a/NEWS.md b/NEWS.md index 43019d877cd0..b067c8e3ca88 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2101,7 +2101,7 @@ This release marks a major milestone for the XGBoost project. ## v0.90 (2019.05.18) ### XGBoost Python package drops Python 2.x (#4379, #4381) -Python 2.x is reaching its end-of-life at the end of this year. [Many scientific Python packages are now moving to drop Python 2.x](https://python3statement.org/). +Python 2.x is reaching its end-of-life at the end of this year. [Many scientific Python packages are now moving to drop Python 2.x](https://python3statement.github.io/). ### XGBoost4J-Spark now requires Spark 2.4.x (#4377) * Spark 2.3 is reaching its end-of-life soon. See discussion at #4389. diff --git a/R-package/CMakeLists.txt b/R-package/CMakeLists.txt index d3a69abc278e..37c5dbf4c1ed 100644 --- a/R-package/CMakeLists.txt +++ b/R-package/CMakeLists.txt @@ -26,7 +26,6 @@ endif() target_compile_definitions( xgboost-r PUBLIC -DXGBOOST_STRICT_R_MODE=1 - -DXGBOOST_CUSTOMIZE_GLOBAL_PRNG=1 -DDMLC_LOG_BEFORE_THROW=0 -DDMLC_DISABLE_STDIN=1 -DDMLC_LOG_CUSTOMIZE=1 diff --git a/R-package/NAMESPACE b/R-package/NAMESPACE index 580d1f87325f..c9e085e77e0a 100644 --- a/R-package/NAMESPACE +++ b/R-package/NAMESPACE @@ -20,15 +20,9 @@ export("xgb.attr<-") export("xgb.attributes<-") export("xgb.config<-") export("xgb.parameters<-") -export(cb.cv.predict) -export(cb.early.stop) -export(cb.evaluation.log) -export(cb.gblinear.history) -export(cb.print.evaluation) -export(cb.reset.parameters) -export(cb.save.model) export(getinfo) export(setinfo) +export(xgb.Callback) export(xgb.DMatrix) export(xgb.DMatrix.hasinfo) export(xgb.DMatrix.save) @@ -39,6 +33,13 @@ export(xgb.QuantileDMatrix) export(xgb.QuantileDMatrix.from_iterator) export(xgb.attr) export(xgb.attributes) +export(xgb.cb.cv.predict) +export(xgb.cb.early.stop) +export(xgb.cb.evaluation.log) +export(xgb.cb.gblinear.history) +export(xgb.cb.print.evaluation) +export(xgb.cb.reset.parameters) +export(xgb.cb.save.model) export(xgb.config) export(xgb.copy.Booster) export(xgb.create.features) @@ -72,14 +73,10 @@ export(xgb.slice.DMatrix) export(xgb.train) export(xgboost) import(methods) +importClassesFrom(Matrix,CsparseMatrix) importClassesFrom(Matrix,dgCMatrix) importClassesFrom(Matrix,dgRMatrix) -importClassesFrom(Matrix,dgeMatrix) -importFrom(Matrix,colSums) importFrom(Matrix,sparse.model.matrix) -importFrom(Matrix,sparseMatrix) -importFrom(Matrix,sparseVector) -importFrom(Matrix,t) importFrom(data.table,":=") importFrom(data.table,as.data.table) importFrom(data.table,data.table) @@ -101,6 +98,7 @@ importFrom(methods,new) importFrom(stats,coef) importFrom(stats,median) importFrom(stats,predict) +importFrom(stats,sd) importFrom(stats,variable.names) importFrom(utils,head) importFrom(utils,object.size) diff --git a/R-package/R/callbacks.R b/R-package/R/callbacks.R index 02e0a7cd4b8e..39734ab092d3 100644 --- a/R-package/R/callbacks.R +++ b/R-package/R/callbacks.R @@ -1,478 +1,833 @@ -#' Callback closures for booster training. +.reserved_cb_names <- c("names", "class", "call", "params", "niter", "nfeatures", "folds") + +#' @title XGBoost Callback Constructor +#' @description Constructor for defining the structure of callback functions that can be executed +#' at different stages of model training (before / after training, before / after each boosting +#' iteration). +#' @param cb_name Name for the callback. #' -#' These are used to perform various service tasks either during boosting iterations or at the end. -#' This approach helps to modularize many of such tasks without bloating the main training methods, -#' and it offers . +#' If the callback produces some non-NULL result (from executing the function passed under +#' `f_after_training`), that result will be added as an R attribute to the resulting booster +#' (or as a named element in the result of CV), with the attribute name specified here. #' -#' @details -#' By default, a callback function is run after each boosting iteration. -#' An R-attribute \code{is_pre_iteration} could be set for a callback to define a pre-iteration function. +#' Names of callbacks must be unique - i.e. there cannot be two callbacks with the same name. +#' @param env An environment object that will be passed to the different functions in the callback. +#' Note that this environment will not be shared with other callbacks. +#' @param f_before_training A function that will be executed before the training has started. #' -#' When a callback function has \code{finalize} parameter, its finalizer part will also be run after -#' the boosting is completed. +#' If passing `NULL` for this or for the other function inputs, then no function will be executed. #' -#' WARNING: side-effects!!! Be aware that these callback functions access and modify things in -#' the environment from which they are called from, which is a fairly uncommon thing to do in R. +#' If passing a function, it will be called with parameters supplied as non-named arguments +#' matching the function signatures that are shown in the default value for each function argument. +#' @param f_before_iter A function that will be executed before each boosting round. #' -#' To write a custom callback closure, make sure you first understand the main concepts about R environments. -#' Check either R documentation on \code{\link[base]{environment}} or the -#' \href{http://adv-r.had.co.nz/Environments.html}{Environments chapter} from the "Advanced R" -#' book by Hadley Wickham. Further, the best option is to read the code of some of the existing callbacks - -#' choose ones that do something similar to what you want to achieve. Also, you would need to get familiar -#' with the objects available inside of the \code{xgb.train} and \code{xgb.cv} internal environments. +#' This function can signal whether the training should be finalized or not, by outputting +#' a value that evaluates to `TRUE` - i.e. if the output from the function provided here at +#' a given round is `TRUE`, then training will be stopped before the current iteration happens. #' -#' @seealso -#' \code{\link{cb.print.evaluation}}, -#' \code{\link{cb.evaluation.log}}, -#' \code{\link{cb.reset.parameters}}, -#' \code{\link{cb.early.stop}}, -#' \code{\link{cb.save.model}}, -#' \code{\link{cb.cv.predict}}, -#' \code{\link{xgb.train}}, -#' \code{\link{xgb.cv}} +#' Return values of `NULL` will be interpreted as `FALSE`. +#' @param f_after_iter A function that will be executed after each boosting round. #' -#' @name callbacks -NULL - -# -# Callbacks ------------------------------------------------------------------- -# - -#' Callback closure for printing the result of evaluation +#' This function can signal whether the training should be finalized or not, by outputting +#' a value that evaluates to `TRUE` - i.e. if the output from the function provided here at +#' a given round is `TRUE`, then training will be stopped at that round. #' -#' @param period results would be printed every number of periods -#' @param showsd whether standard deviations should be printed (when available) +#' Return values of `NULL` will be interpreted as `FALSE`. +#' @param f_after_training A function that will be executed after training is finished. #' -#' @details -#' The callback function prints the result of evaluation at every \code{period} iterations. -#' The initial and the last iteration's evaluations are always printed. +#' This function can optionally output something non-NULL, which will become part of the R +#' attributes of the booster (assuming one passes `keep_extra_attributes=TRUE` to \link{xgb.train}) +#' under the name supplied for parameter `cb_name` imn the case of \link{xgb.train}; or a part +#' of the named elements in the result of \link{xgb.cv}. +#' @return An `xgb.Callback` object, which can be passed to \link{xgb.train} or \link{xgb.cv}. +#' @details Arguments that will be passed to the supplied functions are as follows:\itemize{ #' -#' Callback function expects the following values to be set in its calling frame: -#' \code{bst_evaluation} (also \code{bst_evaluation_err} when available), -#' \code{iteration}, -#' \code{begin_iteration}, -#' \code{end_iteration}. +#' \item env The same environment that is passed under argument `env`. #' -#' @seealso -#' \code{\link{callbacks}} +#' It may be modified by the functions in order to e.g. keep tracking of what happens +#' across iterations or similar. #' -#' @export -cb.print.evaluation <- function(period = 1, showsd = TRUE) { - - callback <- function(env = parent.frame()) { - if (length(env$bst_evaluation) == 0 || - period == 0 || - NVL(env$rank, 0) != 0) - return() - - i <- env$iteration - if ((i - 1) %% period == 0 || - i == env$begin_iteration || - i == env$end_iteration) { - stdev <- if (showsd) env$bst_evaluation_err else NULL - msg <- .format_eval_string(i, env$bst_evaluation, stdev) - cat(msg, '\n') - } - } - attr(callback, 'call') <- match.call() - attr(callback, 'name') <- 'cb.print.evaluation' - callback -} - - -#' Callback closure for logging the evaluation history +#' This environment is only used by the functions supplied to the callback, and will +#' not be kept after the model fitting function terminates (see parameter `f_after_training`). #' -#' @details -#' This callback function appends the current iteration evaluation results \code{bst_evaluation} -#' available in the calling parent frame to the \code{evaluation_log} list in a calling frame. +#' \item model The booster object when using \link{xgb.train}, or the folds when using +#' \link{xgb.cv}. +#' +#' For \link{xgb.cv}, folds are a list with a structure as follows:\itemize{ +#' \item `dtrain`: The training data for the fold (as an `xgb.DMatrix` object). +#' \item `bst`: Rhe `xgb.Booster` object for the fold. +#' \item `evals`: A list containing two DMatrices, with names `train` and `test` +#' (`test` is the held-out data for the fold). +#' \item `index`: The indices of the hold-out data for that fold (base-1 indexing), +#' from which the `test` entry in `evals` was obtained. +#' } #' -#' The finalizer callback (called with \code{finalize = TURE} in the end) converts -#' the \code{evaluation_log} list into a final data.table. +#' This object should \bold{not} be in-place modified in ways that conflict with the +#' training (e.g. resetting the parameters for a training update in a way that resets +#' the number of rounds to zero in order to overwrite rounds). #' -#' The iteration evaluation result \code{bst_evaluation} must be a named numeric vector. +#' Note that any R attributes that are assigned to the booster during the callback functions, +#' will not be kept thereafter as the booster object variable is not re-assigned during +#' training. It is however possible to set C-level attributes of the booster through +#' \link{xgb.attr} or \link{xgb.attributes}, which should remain available for the rest +#' of the iterations and after the training is done. #' -#' Note: in the column names of the final data.table, the dash '-' character is replaced with -#' the underscore '_' in order to make the column names more like regular R identifiers. +#' For keeping variables across iterations, it's recommended to use `env` instead. +#' \item data The data to which the model is being fit, as an `xgb.DMatrix` object. +#' +#' Note that, for \link{xgb.cv}, this will be the full data, while data for the specific +#' folds can be found in the `model` object. +#' +#' \item evals The evaluation data, as passed under argument `evals` to +#' \link{xgb.train}. +#' +#' For \link{xgb.cv}, this will always be `NULL`. +#' +#' \item begin_iteration Index of the first boosting iteration that will be executed +#' (base-1 indexing). +#' +#' This will typically be '1', but when using training continuation, depending on the +#' parameters for updates, boosting rounds will be continued from where the previous +#' model ended, in which case this will be larger than 1. +#' +#' \item end_iteration Index of the last boostign iteration that will be executed +#' (base-1 indexing, inclusive of this end). +#' +#' It should match with argument `nrounds` passed to \link{xgb.train} or \link{xgb.cv}. +#' +#' Note that boosting might be interrupted before reaching this last iteration, for +#' example by using the early stopping callback \link{xgb.cb.early.stop}. +#' +#' \item iteration Index of the iteration number that is being executed (first iteration +#' will be the same as parameter `begin_iteration`, then next one will add +1, and so on). +#' +#' \item iter_feval Evaluation metrics for `evals` that were supplied, either +#' determined by the objective, or by parameter `feval`. +#' +#' For \link{xgb.train}, this will be a named vector with one entry per element in +#' `evals`, where the names are determined as 'evals name' + '-' + 'metric name' - for +#' example, if `evals` contains an entry named "tr" and the metric is "rmse", +#' this will be a one-element vector with name "tr-rmse". +#' +#' For \link{xgb.cv}, this will be a 2d matrix with dimensions `[length(evals), nfolds]`, +#' where the row names will follow the same naming logic as the one-dimensional vector +#' that is passed in \link{xgb.train}. +#' +#' Note that, internally, the built-in callbacks such as \link{xgb.cb.print.evaluation} summarize +#' this table by calculating the row-wise means and standard deviations. +#' +#' \item final_feval The evaluation results after the last boosting round is executed +#' (same format as `iter_feval`, and will be the exact same input as passed under +#' `iter_feval` to the last round that is executed during model fitting). +#' +#' \item prev_cb_res Result from a previous run of a callback sharing the same name +#' (as given by parameter `cb_name`) when conducting training continuation, if there +#' was any in the booster R attributes. +#' +#' Some times, one might want to append the new results to the previous one, and this will +#' be done automatically by the built-in callbacks such as \link{xgb.cb.evaluation.log}, +#' which will append the new rows to the previous table. +#' +#' If no such previous callback result is available (which it never will when fitting +#' a model from start instead of updating an existing model), this will be `NULL`. +#' +#' For \link{xgb.cv}, which doesn't support training continuation, this will always be `NULL`. +#' } +#' +#' The following names (`cb_name` values) are reserved for internal callbacks:\itemize{ +#' \item print_evaluation +#' \item evaluation_log +#' \item reset_parameters +#' \item early_stop +#' \item save_model +#' \item cv_predict +#' \item gblinear_history +#' } #' -#' Callback function expects the following values to be set in its calling frame: -#' \code{evaluation_log}, -#' \code{bst_evaluation}, -#' \code{iteration}. +#' The following names are reserved for other non-callback attributes:\itemize{ +#' \item names +#' \item class +#' \item call +#' \item params +#' \item niter +#' \item nfeatures +#' \item folds +#' } +#' +#' When using the built-in early stopping callback (\link{xgb.cb.early.stop}), said callback +#' will always be executed before the others, as it sets some booster C-level attributes +#' that other callbacks might also use. Otherwise, the order of execution will match with +#' the order in which the callbacks are passed to the model fitting function. +#' @seealso Built-in callbacks:\itemize{ +#' \item \link{xgb.cb.print.evaluation} +#' \item \link{xgb.cb.evaluation.log} +#' \item \link{xgb.cb.reset.parameters} +#' \item \link{xgb.cb.early.stop} +#' \item \link{xgb.cb.save.model} +#' \item \link{xgb.cb.cv.predict} +#' \item \link{xgb.cb.gblinear.history} +#' } +#' @examples +#' # Example constructing a custom callback that calculates +#' # squared error on the training data (no separate test set), +#' # and outputs the per-iteration results. +#' ssq_callback <- xgb.Callback( +#' cb_name = "ssq", +#' f_before_training = function(env, model, data, evals, +#' begin_iteration, end_iteration) { +#' # A vector to keep track of a number at each iteration +#' env$logs <- rep(NA_real_, end_iteration - begin_iteration + 1) +#' }, +#' f_after_iter = function(env, model, data, evals, iteration, iter_feval) { +#' # This calculates the sum of squared errors on the training data. +#' # Note that this can be better done by passing an 'evals' entry, +#' # but this demonstrates a way in which callbacks can be structured. +#' pred <- predict(model, data) +#' err <- pred - getinfo(data, "label") +#' sq_err <- sum(err^2) +#' env$logs[iteration] <- sq_err +#' cat( +#' sprintf( +#' "Squared error at iteration %d: %.2f\n", +#' iteration, sq_err +#' ) +#' ) +#' +#' # A return value of 'TRUE' here would signal to finalize the training +#' return(FALSE) +#' }, +#' f_after_training = function(env, model, data, evals, iteration, +#' final_feval, prev_cb_res) { +#' return(env$logs) +#' } +#' ) #' -#' @seealso -#' \code{\link{callbacks}} +#' data(mtcars) +#' y <- mtcars$mpg +#' x <- as.matrix(mtcars[, -1]) +#' dm <- xgb.DMatrix(x, label = y, nthread = 1) +#' model <- xgb.train( +#' data = dm, +#' params = list(objective = "reg:squarederror", nthread = 1), +#' nrounds = 5, +#' callbacks = list(ssq_callback), +#' keep_extra_attributes = TRUE +#' ) #' +#' # Result from 'f_after_iter' will be available as an attribute +#' attributes(model)$ssq #' @export -cb.evaluation.log <- function() { +xgb.Callback <- function( + cb_name = "custom_callback", + env = new.env(), + f_before_training = function(env, model, data, evals, begin_iteration, end_iteration) NULL, + f_before_iter = function(env, model, data, evals, iteration) NULL, + f_after_iter = function(env, model, data, evals, iteration, iter_feval) NULL, + f_after_training = function(env, model, data, evals, iteration, final_feval, prev_cb_res) NULL +) { + stopifnot(is.null(f_before_training) || is.function(f_before_training)) + stopifnot(is.null(f_before_iter) || is.function(f_before_iter)) + stopifnot(is.null(f_after_iter) || is.function(f_after_iter)) + stopifnot(is.null(f_after_training) || is.function(f_after_training)) + stopifnot(is.character(cb_name) && length(cb_name) == 1) + + if (cb_name %in% .reserved_cb_names) { + stop("Cannot use reserved callback name '", cb_name, "'.") + } - mnames <- NULL + out <- list( + cb_name = cb_name, + env = env, + f_before_training = f_before_training, + f_before_iter = f_before_iter, + f_after_iter = f_after_iter, + f_after_training = f_after_training + ) + class(out) <- "xgb.Callback" + return(out) +} - init <- function(env) { - if (!is.list(env$evaluation_log)) - stop("'evaluation_log' has to be a list") - mnames <<- names(env$bst_evaluation) - if (is.null(mnames) || any(mnames == "")) - stop("bst_evaluation must have non-empty names") +.execute.cb.before.training <- function( + callbacks, + model, + data, + evals, + begin_iteration, + end_iteration +) { + for (callback in callbacks) { + if (!is.null(callback$f_before_training)) { + callback$f_before_training( + callback$env, + model, + data, + evals, + begin_iteration, + end_iteration + ) + } + } +} - mnames <<- gsub('-', '_', names(env$bst_evaluation), fixed = TRUE) - if (!is.null(env$bst_evaluation_err)) - mnames <<- c(paste0(mnames, '_mean'), paste0(mnames, '_std')) +.execute.cb.before.iter <- function( + callbacks, + model, + data, + evals, + iteration +) { + if (!length(callbacks)) { + return(FALSE) } + out <- sapply(callbacks, function(cb) { + if (is.null(cb$f_before_iter)) { + return(FALSE) + } + should_stop <- cb$f_before_iter( + cb$env, + model, + data, + evals, + iteration + ) + if (!NROW(should_stop)) { + should_stop <- FALSE + } else if (NROW(should_stop) > 1) { + should_stop <- head(as.logical(should_stop), 1) + } + return(should_stop) + }) + return(any(out)) +} - finalizer <- function(env) { - env$evaluation_log <- as.data.table(t(simplify2array(env$evaluation_log))) - setnames(env$evaluation_log, c('iter', mnames)) - - if (!is.null(env$bst_evaluation_err)) { - # rearrange col order from _mean,_mean,...,_std,_std,... - # to be _mean,_std,_mean,_std,... - len <- length(mnames) - means <- mnames[seq_len(len / 2)] - stds <- mnames[(len / 2 + 1):len] - cnames <- numeric(len) - cnames[c(TRUE, FALSE)] <- means - cnames[c(FALSE, TRUE)] <- stds - env$evaluation_log <- env$evaluation_log[, c('iter', cnames), with = FALSE] +.execute.cb.after.iter <- function( + callbacks, + model, + data, + evals, + iteration, + iter_feval +) { + if (!length(callbacks)) { + return(FALSE) + } + out <- sapply(callbacks, function(cb) { + if (is.null(cb$f_after_iter)) { + return(FALSE) } + should_stop <- cb$f_after_iter( + cb$env, + model, + data, + evals, + iteration, + iter_feval + ) + if (!NROW(should_stop)) { + should_stop <- FALSE + } else if (NROW(should_stop) > 1) { + should_stop <- head(as.logical(should_stop), 1) + } + return(should_stop) + }) + return(any(out)) +} + +.execute.cb.after.training <- function( + callbacks, + model, + data, + evals, + iteration, + final_feval, + prev_cb_res +) { + if (!length(callbacks)) { + return(NULL) + } + old_cb_res <- attributes(model) + out <- lapply(callbacks, function(cb) { + if (is.null(cb$f_after_training)) { + return(NULL) + } else { + return( + cb$f_after_training( + cb$env, + model, + data, + evals, + iteration, + final_feval, + getElement(old_cb_res, cb$cb_name) + ) + ) + } + }) + names(out) <- sapply(callbacks, function(cb) cb$cb_name) + if (NROW(out)) { + out <- out[!sapply(out, is.null)] } + return(out) +} - callback <- function(env = parent.frame(), finalize = FALSE) { - if (is.null(mnames)) - init(env) +.summarize.feval <- function(iter_feval, showsd) { + if (NCOL(iter_feval) > 1L && showsd) { + stdev <- apply(iter_feval, 1, sd) + } else { + stdev <- NULL + } + if (NCOL(iter_feval) > 1L) { + iter_feval <- rowMeans(iter_feval) + } + return(list(feval = iter_feval, stdev = stdev)) +} - if (finalize) - return(finalizer(env)) +.print.evaluation <- function(iter_feval, showsd, iteration) { + tmp <- .summarize.feval(iter_feval, showsd) + msg <- .format_eval_string(iteration, tmp$feval, tmp$stdev) + cat(msg, '\n') +} - ev <- env$bst_evaluation - if (!is.null(env$bst_evaluation_err)) - ev <- c(ev, env$bst_evaluation_err) - env$evaluation_log <- c(env$evaluation_log, - list(c(iter = env$iteration, ev))) +# Format the evaluation metric string +.format_eval_string <- function(iter, eval_res, eval_err = NULL) { + if (length(eval_res) == 0) + stop('no evaluation results') + enames <- names(eval_res) + if (is.null(enames)) + stop('evaluation results must have names') + iter <- sprintf('[%d]\t', iter) + if (!is.null(eval_err)) { + if (length(eval_res) != length(eval_err)) + stop('eval_res & eval_err lengths mismatch') + # Note: UTF-8 code for plus/minus sign is U+00B1 + res <- paste0(sprintf("%s:%f\U00B1%f", enames, eval_res, eval_err), collapse = '\t') + } else { + res <- paste0(sprintf("%s:%f", enames, eval_res), collapse = '\t') } - attr(callback, 'call') <- match.call() - attr(callback, 'name') <- 'cb.evaluation.log' - callback + return(paste0(iter, res)) } -#' Callback closure for resetting the booster's parameters at each iteration. +#' @title Callback for printing the result of evaluation +#' @param period results would be printed every number of periods +#' @param showsd whether standard deviations should be printed (when available) +#' @return An `xgb.Callback` object, which can be passed to \link{xgb.train} or \link{xgb.cv}. +#' @description +#' The callback function prints the result of evaluation at every \code{period} iterations. +#' The initial and the last iteration's evaluations are always printed. #' +#' Does not leave any attribute in the booster (see \link{xgb.cb.evaluation.log} for that). +#' @seealso \link{xgb.Callback} +#' @export +xgb.cb.print.evaluation <- function(period = 1, showsd = TRUE) { + if (length(period) != 1 || period != floor(period) || period < 1) { + stop("'period' must be a positive integer.") + } + + xgb.Callback( + cb_name = "print_evaluation", + env = as.environment(list(period = period, showsd = showsd, is_first_call = TRUE)), + f_before_training = NULL, + f_before_iter = NULL, + f_after_iter = function(env, model, data, evals, iteration, iter_feval) { + if (is.null(iter_feval)) { + return(FALSE) + } + if (env$is_first_call || (iteration - 1) %% env$period == 0) { + .print.evaluation(iter_feval, env$showsd, iteration) + env$last_printed_iter <- iteration + } + env$is_first_call <- FALSE + return(FALSE) + }, + f_after_training = function(env, model, data, evals, iteration, final_feval, prev_cb_res) { + if (is.null(final_feval)) { + return(NULL) + } + if (is.null(env$last_printed_iter) || iteration > env$last_printed_iter) { + .print.evaluation(final_feval, env$showsd, iteration) + } + } + ) +} + +#' @title Callback for logging the evaluation history +#' @return An `xgb.Callback` object, which can be passed to \link{xgb.train} or \link{xgb.cv}. +#' @details This callback creates a table with per-iteration evaluation metrics (see parameters +#' `evals` and `feval` in \link{xgb.train}). +#' @details +#' Note: in the column names of the final data.table, the dash '-' character is replaced with +#' the underscore '_' in order to make the column names more like regular R identifiers. +#' @seealso \link{xgb.cb.print.evaluation} +#' @export +xgb.cb.evaluation.log <- function() { + xgb.Callback( + cb_name = "evaluation_log", + f_before_training = function(env, model, data, evals, begin_iteration, end_iteration) { + env$evaluation_log <- vector("list", end_iteration - begin_iteration + 1) + env$next_log <- 1 + }, + f_before_iter = NULL, + f_after_iter = function(env, model, data, evals, iteration, iter_feval) { + tmp <- .summarize.feval(iter_feval, TRUE) + env$evaluation_log[[env$next_log]] <- list(iter = iteration, metrics = tmp$feval, sds = tmp$stdev) + env$next_log <- env$next_log + 1 + return(FALSE) + }, + f_after_training = function(env, model, data, evals, iteration, final_feval, prev_cb_res) { + if (!NROW(env$evaluation_log)) { + return(prev_cb_res) + } + # in case of early stopping + if (env$next_log <= length(env$evaluation_log)) { + env$evaluation_log <- head(env$evaluation_log, env$next_log - 1) + } + + iters <- data.frame(iter = sapply(env$evaluation_log, function(x) x$iter)) + metrics <- do.call(rbind, lapply(env$evaluation_log, function(x) x$metrics)) + mnames <- gsub("-", "_", names(env$evaluation_log[[1]]$metrics), fixed = TRUE) + colnames(metrics) <- mnames + has_sds <- !is.null(env$evaluation_log[[1]]$sds) + if (has_sds) { + sds <- do.call(rbind, lapply(env$evaluation_log, function(x) x$sds)) + colnames(sds) <- mnames + metrics <- lapply( + mnames, + function(metric) { + out <- cbind(metrics[, metric], sds[, metric]) + colnames(out) <- paste0(metric, c("_mean", "_std")) + return(out) + } + ) + metrics <- do.call(cbind, metrics) + } + evaluation_log <- cbind(iters, metrics) + + if (!is.null(prev_cb_res)) { + if (!is.data.table(prev_cb_res)) { + prev_cb_res <- data.table::as.data.table(prev_cb_res) + } + prev_take <- prev_cb_res[prev_cb_res$iter < min(evaluation_log$iter)] + if (nrow(prev_take)) { + evaluation_log <- rbind(prev_cb_res, evaluation_log) + } + } + evaluation_log <- data.table::as.data.table(evaluation_log) + return(evaluation_log) + } + ) +} + +#' @title Callback for resetting the booster's parameters at each iteration. #' @param new_params a list where each element corresponds to a parameter that needs to be reset. #' Each element's value must be either a vector of values of length \code{nrounds} #' to be set at each iteration, #' or a function of two parameters \code{learning_rates(iteration, nrounds)} #' which returns a new parameter value by using the current iteration number #' and the total number of boosting rounds. -#' +#' @return An `xgb.Callback` object, which can be passed to \link{xgb.train} or \link{xgb.cv}. #' @details -#' This is a "pre-iteration" callback function used to reset booster's parameters -#' at the beginning of each iteration. -#' #' Note that when training is resumed from some previous model, and a function is used to #' reset a parameter value, the \code{nrounds} argument in this function would be the #' the number of boosting rounds in the current training. #' -#' Callback function expects the following values to be set in its calling frame: -#' \code{bst} or \code{bst_folds}, -#' \code{iteration}, -#' \code{begin_iteration}, -#' \code{end_iteration}. -#' -#' @seealso -#' \code{\link{callbacks}} -#' +#' Does not leave any attribute in the booster. #' @export -cb.reset.parameters <- function(new_params) { - - if (typeof(new_params) != "list") - stop("'new_params' must be a list") +xgb.cb.reset.parameters <- function(new_params) { + stopifnot(is.list(new_params)) pnames <- gsub(".", "_", names(new_params), fixed = TRUE) - nrounds <- NULL - - # run some checks in the beginning - init <- function(env) { - nrounds <<- env$end_iteration - env$begin_iteration + 1 - - if (is.null(env$bst) && is.null(env$bst_folds)) - stop("Parent frame has neither 'bst' nor 'bst_folds'") - - # Some parameters are not allowed to be changed, - # since changing them would simply wreck some chaos - not_allowed <- pnames %in% - c('num_class', 'num_output_group', 'size_leaf_vector', 'updater_seq') - if (any(not_allowed)) - stop('Parameters ', paste(pnames[not_allowed]), " cannot be changed during boosting.") - - for (n in pnames) { - p <- new_params[[n]] - if (is.function(p)) { - if (length(formals(p)) != 2) - stop("Parameter '", n, "' is a function but not of two arguments") - } else if (is.numeric(p) || is.character(p)) { - if (length(p) != nrounds) - stop("Length of '", n, "' has to be equal to 'nrounds'") - } else { - stop("Parameter '", n, "' is not a function or a vector") + not_allowed <- pnames %in% + c('num_class', 'num_output_group', 'size_leaf_vector', 'updater_seq') + if (any(not_allowed)) + stop('Parameters ', paste(pnames[not_allowed]), " cannot be changed during boosting.") + + xgb.Callback( + cb_name = "reset_parameters", + env = as.environment(list(new_params = new_params)), + f_before_training = function(env, model, data, evals, begin_iteration, end_iteration) { + env$end_iteration <- end_iteration + + pnames <- gsub(".", "_", names(env$new_params), fixed = TRUE) + for (n in pnames) { + p <- env$new_params[[n]] + if (is.function(p)) { + if (length(formals(p)) != 2) + stop("Parameter '", n, "' is a function but not of two arguments") + } else if (is.numeric(p) || is.character(p)) { + if (length(p) != env$end_iteration) + stop("Length of '", n, "' has to be equal to 'nrounds'") + } else { + stop("Parameter '", n, "' is not a function or a vector") + } } - } - } - - callback <- function(env = parent.frame()) { - if (is.null(nrounds)) - init(env) - - i <- env$iteration - pars <- lapply(new_params, function(p) { - if (is.function(p)) - return(p(i, nrounds)) - p[i] - }) + }, + f_before_iter = function(env, model, data, evals, iteration) { + pars <- lapply(env$new_params, function(p) { + if (is.function(p)) { + return(p(iteration, env$end_iteration)) + } else { + return(p[iteration]) + } + }) - if (!is.null(env$bst)) { - xgb.parameters(env$bst) <- pars - } else { - for (fd in env$bst_folds) - xgb.parameters(fd$bst) <- pars - } - } - attr(callback, 'is_pre_iteration') <- TRUE - attr(callback, 'call') <- match.call() - attr(callback, 'name') <- 'cb.reset.parameters' - callback + if (inherits(model, "xgb.Booster")) { + xgb.parameters(model) <- pars + } else { + for (fd in model) { + xgb.parameters(fd$bst) <- pars + } + } + return(FALSE) + }, + f_after_iter = NULL, + f_after_training = NULL + ) } - -#' Callback closure to activate the early stopping. -#' +#' @title Callback to activate early stopping #' @param stopping_rounds The number of rounds with no improvement in #' the evaluation metric in order to stop the training. -#' @param maximize whether to maximize the evaluation metric -#' @param metric_name the name of an evaluation column to use as a criteria for early +#' @param maximize Whether to maximize the evaluation metric. +#' @param metric_name The name of an evaluation column to use as a criteria for early #' stopping. If not set, the last column would be used. -#' Let's say the test data in \code{watchlist} was labelled as \code{dtest}, +#' Let's say the test data in \code{evals} was labelled as \code{dtest}, #' and one wants to use the AUC in test data for early stopping regardless of where -#' it is in the \code{watchlist}, then one of the following would need to be set: +#' it is in the \code{evals}, then one of the following would need to be set: #' \code{metric_name='dtest-auc'} or \code{metric_name='dtest_auc'}. #' All dash '-' characters in metric names are considered equivalent to '_'. -#' @param verbose whether to print the early stopping information. +#' @param verbose Whether to print the early stopping information. +#' @param keep_all_iter Whether to keep all of the boosting rounds that were produced +#' in the resulting object. If passing `FALSE`, will only keep the boosting rounds +#' up to the detected best iteration, discarding the ones that come after. +#' @return An `xgb.Callback` object, which can be passed to \link{xgb.train} or \link{xgb.cv}. +#' @description +#' This callback function determines the condition for early stopping. #' -#' @details -#' This callback function determines the condition for early stopping -#' by setting the \code{stop_condition = TRUE} flag in its calling frame. -#' -#' The following additional fields are assigned to the model's R object: +#' The following attributes are assigned to the booster's object: #' \itemize{ #' \item \code{best_score} the evaluation score at the best iteration -#' \item \code{best_iteration} at which boosting iteration the best score has occurred (1-based index) +#' \item \code{best_iteration} at which boosting iteration the best score has occurred +#' (0-based index for interoperability of binary models) #' } -#' The Same values are also stored as xgb-attributes: -#' \itemize{ -#' \item \code{best_iteration} is stored as a 0-based iteration index (for interoperability of binary models) -#' \item \code{best_msg} message string is also stored. -#' } -#' -#' At least one data element is required in the evaluation watchlist for early stopping to work. #' -#' Callback function expects the following values to be set in its calling frame: -#' \code{stop_condition}, -#' \code{bst_evaluation}, -#' \code{rank}, -#' \code{bst} (or \code{bst_folds} and \code{basket}), -#' \code{iteration}, -#' \code{begin_iteration}, -#' \code{end_iteration}, -#' -#' @seealso -#' \code{\link{callbacks}}, -#' \code{\link{xgb.attr}} +#' The same values are also stored as R attributes as a result of the callback, plus an additional +#' attribute `stopped_by_max_rounds` which indicates whether an early stopping by the `stopping_rounds` +#' condition occurred. Note that the `best_iteration` that is stored under R attributes will follow +#' base-1 indexing, so it will be larger by '1' than the C-level 'best_iteration' that is accessed +#' through \link{xgb.attr} or \link{xgb.attributes}. #' +#' At least one dataset is required in `evals` for early stopping to work. #' @export -cb.early.stop <- function(stopping_rounds, maximize = FALSE, - metric_name = NULL, verbose = TRUE) { - # state variables - best_iteration <- -1 - best_score <- Inf - best_msg <- NULL - metric_idx <- 1 - - init <- function(env) { - if (length(env$bst_evaluation) == 0) - stop("For early stopping, watchlist must have at least one element") - - eval_names <- gsub('-', '_', names(env$bst_evaluation), fixed = TRUE) - if (!is.null(metric_name)) { - metric_idx <<- which(gsub('-', '_', metric_name, fixed = TRUE) == eval_names) - if (length(metric_idx) == 0) - stop("'metric_name' for early stopping is not one of the following:\n", - paste(eval_names, collapse = ' '), '\n') - } - if (is.null(metric_name) && - length(env$bst_evaluation) > 1) { - metric_idx <<- length(eval_names) - if (verbose) - cat('Multiple eval metrics are present. Will use ', - eval_names[metric_idx], ' for early stopping.\n', sep = '') - } - - metric_name <<- eval_names[metric_idx] +xgb.cb.early.stop <- function( + stopping_rounds, + maximize = FALSE, + metric_name = NULL, + verbose = TRUE, + keep_all_iter = TRUE +) { + if (!is.null(metric_name)) { + stopifnot(is.character(metric_name)) + stopifnot(length(metric_name) == 1L) + } - # maximize is usually NULL when not set in xgb.train and built-in metrics - if (is.null(maximize)) - maximize <<- grepl('(_auc|_map|_ndcg|_pre)', metric_name) + xgb.Callback( + cb_name = "early_stop", + env = as.environment( + list( + checked_evnames = FALSE, + stopping_rounds = stopping_rounds, + maximize = maximize, + metric_name = metric_name, + verbose = verbose, + keep_all_iter = keep_all_iter, + stopped_by_max_rounds = FALSE + ) + ), + f_before_training = function(env, model, data, evals, begin_iteration, end_iteration) { + if (inherits(model, "xgb.Booster") && !length(evals)) { + stop("For early stopping, 'evals' must have at least one element") + } + env$begin_iteration <- begin_iteration + return(NULL) + }, + f_before_iter = function(env, model, data, evals, iteration) NULL, + f_after_iter = function(env, model, data, evals, iteration, iter_feval) { + sds <- NULL + if (NCOL(iter_feval) > 1) { + tmp <- .summarize.feval(iter_feval, TRUE) + iter_feval <- tmp$feval + sds <- tmp$stdev + } - if (verbose && NVL(env$rank, 0) == 0) - cat("Will train until ", metric_name, " hasn't improved in ", - stopping_rounds, " rounds.\n\n", sep = '') + if (!env$checked_evnames) { - best_iteration <<- 1 - if (maximize) best_score <<- -Inf + eval_names <- gsub('-', '_', names(iter_feval), fixed = TRUE) + if (!is.null(env$metric_name)) { + env$metric_idx <- which(gsub('-', '_', env$metric_name, fixed = TRUE) == eval_names) + if (length(env$metric_idx) == 0) + stop("'metric_name' for early stopping is not one of the following:\n", + paste(eval_names, collapse = ' '), '\n') + } - env$stop_condition <- FALSE + if (is.null(env$metric_name)) { + if (NROW(iter_feval) == 1) { + env$metric_idx <- 1L + } else { + env$metric_idx <- length(eval_names) + if (env$verbose) + cat('Multiple eval metrics are present. Will use ', + eval_names[env$metric_idx], ' for early stopping.\n', sep = '') + } + } - if (!is.null(env$bst)) { - if (!inherits(env$bst, 'xgb.Booster')) - stop("'bst' in the parent frame must be an 'xgb.Booster'") - if (!is.null(best_score <- xgb.attr(env$bst, 'best_score'))) { - best_score <<- as.numeric(best_score) - best_iteration <<- as.numeric(xgb.attr(env$bst, 'best_iteration')) + 1 - best_msg <<- as.numeric(xgb.attr(env$bst, 'best_msg')) - } else { - xgb.attributes(env$bst) <- list(best_iteration = best_iteration - 1, - best_score = best_score) - } - } else if (is.null(env$bst_folds) || is.null(env$basket)) { - stop("Parent frame has neither 'bst' nor ('bst_folds' and 'basket')") - } - } + env$metric_name <- eval_names[env$metric_idx] - finalizer <- function(env) { - if (!is.null(env$bst)) { - attr_best_score <- as.numeric(xgb.attr(env$bst, 'best_score')) - if (best_score != attr_best_score) { - # If the difference is too big, throw an error - if (abs(best_score - attr_best_score) >= 1e-14) { - stop("Inconsistent 'best_score' values between the closure state: ", best_score, - " and the xgb.attr: ", attr_best_score) - } - # If the difference is due to floating-point truncation, update best_score - best_score <- attr_best_score - } - xgb.attr(env$bst, "best_iteration") <- best_iteration - 1 - xgb.attr(env$bst, "best_score") <- best_score - } else { - env$basket$best_iteration <- best_iteration - } - } + # maximize is usually NULL when not set in xgb.train and built-in metrics + if (is.null(env$maximize)) + env$maximize <- grepl('(_auc|_aupr|_map|_ndcg|_pre)', env$metric_name) - callback <- function(env = parent.frame(), finalize = FALSE) { - if (best_iteration < 0) - init(env) + if (env$verbose) + cat("Will train until ", env$metric_name, " hasn't improved in ", + env$stopping_rounds, " rounds.\n\n", sep = '') - if (finalize) - return(finalizer(env)) + env$best_iteration <- env$begin_iteration + if (env$maximize) { + env$best_score <- -Inf + } else { + env$best_score <- Inf + } - i <- env$iteration - score <- env$bst_evaluation[metric_idx] + if (inherits(model, "xgb.Booster")) { + best_score <- xgb.attr(model, 'best_score') + if (NROW(best_score)) env$best_score <- as.numeric(best_score) + best_iteration <- xgb.attr(model, 'best_iteration') + if (NROW(best_iteration)) env$best_iteration <- as.numeric(best_iteration) + 1 + } - if ((maximize && score > best_score) || - (!maximize && score < best_score)) { + env$checked_evnames <- TRUE + } - best_msg <<- .format_eval_string( - i, env$bst_evaluation, env$bst_evaluation_err - ) - best_score <<- score - best_iteration <<- i - # save the property to attributes, so they will occur in checkpoint - if (!is.null(env$bst)) { - xgb.attributes(env$bst) <- list( - best_iteration = best_iteration - 1, # convert to 0-based index - best_score = best_score, - best_msg = best_msg - ) + score <- iter_feval[env$metric_idx] + if ((env$maximize && score > env$best_score) || + (!env$maximize && score < env$best_score)) { + + env$best_score <- score + env$best_iteration <- iteration + # save the property to attributes, so they will occur in checkpoint + if (inherits(model, "xgb.Booster")) { + xgb.attributes(model) <- list( + best_iteration = env$best_iteration - 1, # convert to 0-based index + best_score = env$best_score + ) + } + } else if (iteration - env$best_iteration >= env$stopping_rounds) { + if (env$verbose) { + best_msg <- .format_eval_string(iteration, iter_feval, sds) + cat("Stopping. Best iteration:\n", best_msg, "\n\n", sep = '') + } + env$stopped_by_max_rounds <- TRUE + return(TRUE) } - } else if (i - best_iteration >= stopping_rounds) { - env$stop_condition <- TRUE - env$end_iteration <- i - if (verbose && NVL(env$rank, 0) == 0) - cat("Stopping. Best iteration:\n", best_msg, "\n\n", sep = '') + return(FALSE) + }, + f_after_training = function(env, model, data, evals, iteration, final_feval, prev_cb_res) { + if (inherits(model, "xgb.Booster") && !env$keep_all_iter && env$best_iteration < iteration) { + # Note: it loses the attributes after being sliced, + # so they have to be re-assigned afterwards. + prev_attr <- xgb.attributes(model) + if (NROW(prev_attr)) { + suppressWarnings({ + prev_attr <- within(prev_attr, rm("best_score", "best_iteration")) + }) + } + .Call(XGBoosterSliceAndReplace_R, xgb.get.handle(model), 0L, env$best_iteration, 1L) + if (NROW(prev_attr)) { + xgb.attributes(model) <- prev_attr + } + } + attrs_set <- list(best_iteration = env$best_iteration - 1, best_score = env$best_score) + if (inherits(model, "xgb.Booster")) { + xgb.attributes(model) <- attrs_set + } else { + for (fd in model) { + xgb.attributes(fd$bst) <- attrs_set # to use in the cv.predict callback + } + } + return( + list( + best_iteration = env$best_iteration, + best_score = env$best_score, + stopped_by_max_rounds = env$stopped_by_max_rounds + ) + ) } - } - attr(callback, 'call') <- match.call() - attr(callback, 'name') <- 'cb.early.stop' - callback + ) } +.save.model.w.formatted.name <- function(model, save_name, iteration) { + # Note: this throws a warning if the name doesn't have anything to format through 'sprintf' + suppressWarnings({ + save_name <- sprintf(save_name, iteration) + }) + xgb.save(model, save_name) +} -#' Callback closure for saving a model file. -#' -#' @param save_period save the model to disk after every +#' @title Callback for saving a model file. +#' @param save_period Save the model to disk after every #' \code{save_period} iterations; 0 means save the model at the end. -#' @param save_name the name or path for the saved model file. -#' -#' Note that the format of the model being saved is determined by the file -#' extension specified here (see \link{xgb.save} for details about how it works). -#' +#' @param save_name The name or path for the saved model file. #' It can contain a \code{\link[base]{sprintf}} formatting specifier #' to include the integer iteration number in the file name. -#' E.g., with \code{save_name} = 'xgboost_%04d.ubj', -#' the file saved at iteration 50 would be named "xgboost_0050.ubj". -#' @seealso \link{xgb.save} -#' @details -#' This callback function allows to save an xgb-model file, either periodically after each \code{save_period}'s or at the end. -#' -#' Callback function expects the following values to be set in its calling frame: -#' \code{bst}, -#' \code{iteration}, -#' \code{begin_iteration}, -#' \code{end_iteration}. -#' -#' @seealso -#' \code{\link{callbacks}} +#' E.g., with \code{save_name} = 'xgboost_%04d.model', +#' the file saved at iteration 50 would be named "xgboost_0050.model". +#' @return An `xgb.Callback` object, which can be passed to \link{xgb.train}, +#' but \bold{not} to \link{xgb.cv}. +#' @description +#' This callback function allows to save an xgb-model file, either periodically +#' after each \code{save_period}'s or at the end. #' +#' Does not leave any attribute in the booster. #' @export -cb.save.model <- function(save_period = 0, save_name = "xgboost.ubj") { - - if (save_period < 0) +xgb.cb.save.model <- function(save_period = 0, save_name = "xgboost.ubj") { + if (save_period < 0) { stop("'save_period' cannot be negative") + } + if (!is.character(save_name) || length(save_name) != 1L) { + stop("'save_name' must be a single character refering to file name.") + } - callback <- function(env = parent.frame()) { - if (is.null(env$bst)) - stop("'save_model' callback requires the 'bst' booster object in its calling frame") - - if ((save_period > 0 && (env$iteration - env$begin_iteration) %% save_period == 0) || - (save_period == 0 && env$iteration == env$end_iteration)) { - # Note: this throws a warning if the name doesn't have anything to format through 'sprintf' - suppressWarnings({ - save_name <- sprintf(save_name, env$iteration) - }) - xgb.save(env$bst, save_name) + xgb.Callback( + cb_name = "save_model", + env = as.environment(list(save_period = save_period, save_name = save_name, last_save = 0)), + f_before_training = function(env, model, data, evals, begin_iteration, end_iteration) { + env$begin_iteration <- begin_iteration + }, + f_before_iter = NULL, + f_after_iter = function(env, model, data, evals, iteration, iter_feval) { + if (env$save_period > 0 && (iteration - env$begin_iteration) %% env$save_period == 0) { + .save.model.w.formatted.name(model, env$save_name, iteration) + env$last_save <- iteration + } + return(FALSE) + }, + f_after_training = function(env, model, data, evals, iteration, final_feval, prev_cb_res) { + if (env$save_period == 0 && iteration > env$last_save) { + .save.model.w.formatted.name(model, env$save_name, iteration) + } } - } - attr(callback, 'call') <- match.call() - attr(callback, 'name') <- 'cb.save.model' - callback + ) } - -#' Callback closure for returning cross-validation based predictions. -#' -#' @param save_models a flag for whether to save the folds' models. -#' -#' @details +#' @title Callback for returning cross-validation based predictions. +#' @param save_models A flag for whether to save the folds' models. +#' @param outputmargin Whether to save margin predictions (same effect as passing this +#' parameter to \link{predict.xgb.Booster}). +#' @return An `xgb.Callback` object, which can be passed to \link{xgb.cv}, +#' but \bold{not} to \link{xgb.train}. +#' @description #' This callback function saves predictions for all of the test folds, #' and also allows to save the folds' models. -#' -#' It is a "finalizer" callback and it uses early stopping information whenever it is available, -#' thus it must be run after the early stopping callback if the early stopping is used. -#' -#' Callback function expects the following values to be set in its calling frame: -#' \code{bst_folds}, -#' \code{basket}, -#' \code{data}, -#' \code{end_iteration}, -#' \code{params}, -#' -#' @return -#' Predictions are returned inside of the \code{pred} element, which is either a vector or a matrix, +#' @details +#' Predictions are saved inside of the \code{pred} element, which is either a vector or a matrix, #' depending on the number of prediction outputs per data row. The order of predictions corresponds #' to the order of rows in the original dataset. Note that when a custom \code{folds} list is #' provided in \code{xgb.cv}, the predictions would only be returned properly when this list is a @@ -480,84 +835,107 @@ cb.save.model <- function(save_period = 0, save_name = "xgboost.ubj") { #' meaningful when user-provided folds have overlapping indices as in, e.g., random sampling splits. #' When some of the indices in the training dataset are not included into user-provided \code{folds}, #' their prediction value would be \code{NA}. -#' -#' @seealso -#' \code{\link{callbacks}} -#' #' @export -cb.cv.predict <- function(save_models = FALSE) { - - finalizer <- function(env) { - if (is.null(env$basket) || is.null(env$bst_folds)) - stop("'cb.cv.predict' callback requires 'basket' and 'bst_folds' lists in its calling frame") - - N <- nrow(env$data) - pred <- NULL - - iterationrange <- c(1, NVL(env$basket$best_iteration, env$end_iteration)) - if (NVL(env$params[['booster']], '') == 'gblinear') { - iterationrange <- "all" - } - for (fd in env$bst_folds) { - pr <- predict(fd$bst, fd$watchlist[[2]], iterationrange = iterationrange, reshape = TRUE) - if (is.null(pred)) { - if (NCOL(pr) > 1L) { - pred <- matrix(NA_real_, N, ncol(pr)) +xgb.cb.cv.predict <- function(save_models = FALSE, outputmargin = FALSE) { + xgb.Callback( + cb_name = "cv_predict", + env = as.environment(list(save_models = save_models, outputmargin = outputmargin)), + f_before_training = function(env, model, data, evals, begin_iteration, end_iteration) { + if (inherits(model, "xgb.Booster")) { + stop("'cv.predict' callback is only for 'xgb.cv'.") + } + }, + f_before_iter = NULL, + f_after_iter = NULL, + f_after_training = function(env, model, data, evals, iteration, final_feval, prev_cb_res) { + pred <- NULL + for (fd in model) { + pr <- predict( + fd$bst, + fd$evals[[2L]], + outputmargin = env$outputmargin, + reshape = TRUE + ) + if (is.null(pred)) { + if (NCOL(pr) > 1L) { + pred <- matrix(NA_real_, nrow(data), ncol(pr)) + } else { + pred <- matrix(NA_real_, nrow(data)) + } + } + if (is.matrix(pred)) { + pred[fd$index, ] <- pr } else { - pred <- matrix(NA_real_, N) + pred[fd$index] <- pr } } - if (is.matrix(pred)) { - pred[fd$index, ] <- pr - } else { - pred[fd$index] <- pr + out <- list(pred = pred) + if (env$save_models) { + out$models <- lapply(model, function(fd) fd$bst) } + return(out) } - env$basket$pred <- pred - if (save_models) { - env$basket$models <- lapply(env$bst_folds, function(fd) { - return(fd$bst) - }) - } - } + ) +} - callback <- function(env = parent.frame(), finalize = FALSE) { - if (finalize) - return(finalizer(env)) +.list2mat <- function(coef_list, sparse) { + if (sparse) { + coef_mat <- methods::new("dgRMatrix") + coef_mat@p <- as.integer(c(0, cumsum(sapply(coef_list, function(x) length(x@x))))) + coef_mat@j <- as.integer(unlist(lapply(coef_list, slot, "i")) - 1L) + coef_mat@x <- unlist(lapply(coef_list, slot, "x")) + coef_mat@Dim <- as.integer(c(length(coef_list), length(coef_list[[1L]]))) + # Note: function 'xgb.gblinear.history' might later on try to slice by columns + coef_mat <- methods::as(coef_mat, "CsparseMatrix") + return(coef_mat) + } else { + return(unname(do.call(rbind, coef_list))) } - attr(callback, 'call') <- match.call() - attr(callback, 'name') <- 'cb.cv.predict' - callback } +.extract.coef <- function(model, sparse) { + coefs <- .internal.coef.xgb.Booster(model, add_names = FALSE) + if (NCOL(coefs) > 1L) { + coefs <- as.vector(coefs) + } + if (sparse) { + coefs <- methods::as(coefs, "sparseVector") + } + return(coefs) +} -#' Callback closure for collecting the model coefficients history of a gblinear booster -#' during its training. -#' -#' @param sparse when set to FALSE/TRUE, a dense/sparse matrix is used to store the result. +#' @title Callback for collecting coefficients history of a gblinear booster +#' @param sparse when set to `FALSE`/`TRUE`, a dense/sparse matrix is used to store the result. #' Sparse format is useful when one expects only a subset of coefficients to be non-zero, #' when using the "thrifty" feature selector with fairly small number of top features #' selected per iteration. -#' +#' @return An `xgb.Callback` object, which can be passed to \link{xgb.train} or \link{xgb.cv}. #' @details #' To keep things fast and simple, gblinear booster does not internally store the history of linear #' model coefficients at each boosting iteration. This callback provides a workaround for storing #' the coefficients' path, by extracting them after each training iteration. #' -#' Callback function expects the following values to be set in its calling frame: -#' \code{bst} (or \code{bst_folds}). +#' This callback will construct a matrix where rows are boosting iterations and columns are +#' feature coefficients (same order as when calling \link{coef.xgb.Booster}, with the intercept +#' corresponding to the first column). #' -#' @return -#' Results are stored in the \code{coefs} element of the closure. -#' The \code{\link{xgb.gblinear.history}} convenience function provides an easy -#' way to access it. -#' With \code{xgb.train}, it is either a dense of a sparse matrix. -#' While with \code{xgb.cv}, it is a list (an element per each fold) of such -#' matrices. +#' When there is more than one coefficient per feature (e.g. multi-class classification), +#' the result will be reshaped into a vector where coefficients are arranged first by features and +#' then by class (e.g. first 1 through N coefficients will be for the first class, then +#' coefficients N+1 through 2N for the second class, and so on). +#' +#' If the result has only one coefficient per feature in the data, then the resulting matrix +#' will have column names matching with the feature names, otherwise (when there's more than +#' one coefficient per feature) the names will be composed as 'column name' + ':' + 'class index' +#' (so e.g. column 'c1' for class '0' will be named 'c1:0'). #' -#' @seealso -#' \code{\link{callbacks}}, \code{\link{xgb.gblinear.history}}. +#' With \code{xgb.train}, the output is either a dense or a sparse matrix. +#' With with \code{xgb.cv}, it is a list (one element per each fold) of such +#' matrices. #' +#' Function \link{xgb.gblinear.history} function provides an easy way to retrieve the +#' outputs from this callback. +#' @seealso \link{xgb.gblinear.history}, \link{coef.xgb.Booster}. #' @examples #' #### Binary classification: #' @@ -577,7 +955,7 @@ cb.cv.predict <- function(save_models = FALSE) { #' # rate does not break the convergence, but allows us to illustrate the typical pattern of #' # "stochastic explosion" behaviour of this lock-free algorithm at early boosting iterations. #' bst <- xgb.train(param, dtrain, list(tr=dtrain), nrounds = 200, eta = 1., -#' callbacks = list(cb.gblinear.history())) +#' callbacks = list(xgb.cb.gblinear.history())) #' # Extract the coefficients' path and plot them vs boosting iteration number: #' coef_path <- xgb.gblinear.history(bst) #' matplot(coef_path, type = 'l') @@ -586,7 +964,7 @@ cb.cv.predict <- function(save_models = FALSE) { #' # Will try the classical componentwise boosting which selects a single best feature per round: #' bst <- xgb.train(param, dtrain, list(tr=dtrain), nrounds = 200, eta = 0.8, #' updater = 'coord_descent', feature_selector = 'thrifty', top_k = 1, -#' callbacks = list(cb.gblinear.history())) +#' callbacks = list(xgb.cb.gblinear.history())) #' matplot(xgb.gblinear.history(bst), type = 'l') #' # Componentwise boosting is known to have similar effect to Lasso regularization. #' # Try experimenting with various values of top_k, eta, nrounds, @@ -594,7 +972,7 @@ cb.cv.predict <- function(save_models = FALSE) { #' #' # For xgb.cv: #' bst <- xgb.cv(param, dtrain, nfold = 5, nrounds = 100, eta = 0.8, -#' callbacks = list(cb.gblinear.history())) +#' callbacks = list(xgb.cb.gblinear.history())) #' # coefficients in the CV fold #3 #' matplot(xgb.gblinear.history(bst)[[3]], type = 'l') #' @@ -607,7 +985,7 @@ cb.cv.predict <- function(save_models = FALSE) { #' # For the default linear updater 'shotgun' it sometimes is helpful #' # to use smaller eta to reduce instability #' bst <- xgb.train(param, dtrain, list(tr=dtrain), nrounds = 50, eta = 0.5, -#' callbacks = list(cb.gblinear.history())) +#' callbacks = list(xgb.cb.gblinear.history())) #' # Will plot the coefficient paths separately for each class: #' matplot(xgb.gblinear.history(bst, class_index = 0), type = 'l') #' matplot(xgb.gblinear.history(bst, class_index = 1), type = 'l') @@ -615,104 +993,141 @@ cb.cv.predict <- function(save_models = FALSE) { #' #' # CV: #' bst <- xgb.cv(param, dtrain, nfold = 5, nrounds = 70, eta = 0.5, -#' callbacks = list(cb.gblinear.history(FALSE))) +#' callbacks = list(xgb.cb.gblinear.history(FALSE))) #' # 1st fold of 1st class #' matplot(xgb.gblinear.history(bst, class_index = 0)[[1]], type = 'l') #' #' @export -cb.gblinear.history <- function(sparse = FALSE) { - coefs <- NULL - - init <- function(env) { - # xgb.train(): bst will be present - # xgb.cv(): bst_folds will be present - if (is.null(env$bst) && is.null(env$bst_folds)) { - stop("Parent frame has neither 'bst' nor 'bst_folds'") - } - } - - # convert from list to (sparse) matrix - list2mat <- function(coef_list) { - if (sparse) { - coef_mat <- sparseMatrix(x = unlist(lapply(coef_list, slot, "x")), - i = unlist(lapply(coef_list, slot, "i")), - p = c(0, cumsum(sapply(coef_list, function(x) length(x@x)))), - dims = c(length(coef_list[[1]]), length(coef_list))) - return(t(coef_mat)) - } else { - return(do.call(rbind, coef_list)) - } - } - - finalizer <- function(env) { - if (length(coefs) == 0) - return() - if (!is.null(env$bst)) { # # xgb.train: - coefs <<- list2mat(coefs) - } else { # xgb.cv: - # second lapply transposes the list - coefs <<- lapply( - X = lapply( - X = seq_along(coefs[[1]]), - FUN = function(i) lapply(coefs, "[[", i) - ), - FUN = list2mat - ) - } - } +xgb.cb.gblinear.history <- function(sparse = FALSE) { + xgb.Callback( + cb_name = "gblinear_history", + env = as.environment(list(sparse = sparse)), + f_before_training = function(env, model, data, evals, begin_iteration, end_iteration) { + if (!inherits(model, "xgb.Booster")) { + model <- model[[1L]]$bst + } + if (xgb.booster_type(model) != "gblinear") { + stop("Callback 'xgb.cb.gblinear.history' is only for booster='gblinear'.") + } + env$coef_hist <- vector("list", end_iteration - begin_iteration + 1) + env$next_idx <- 1 + }, + f_before_iter = NULL, + f_after_iter = function(env, model, data, evals, iteration, iter_feval) { + if (inherits(model, "xgb.Booster")) { + coef_this <- .extract.coef(model, env$sparse) + } else { + coef_this <- lapply(model, function(fd) .extract.coef(fd$bst, env$sparse)) + } + env$coef_hist[[env$next_idx]] <- coef_this + env$next_idx <- env$next_idx + 1 + return(FALSE) + }, + f_after_training = function(env, model, data, evals, iteration, final_feval, prev_cb_res) { + # in case of early stopping + if (env$next_idx <= length(env$coef_hist)) { + env$coef_hist <- head(env$coef_hist, env$next_idx - 1) + } - extract.coef <- function(env) { - if (!is.null(env$bst)) { # # xgb.train: - cf <- as.numeric(grep('(booster|bias|weigh)', xgb.dump(env$bst), invert = TRUE, value = TRUE)) - if (sparse) cf <- as(cf, "sparseVector") - } else { # xgb.cv: - cf <- vector("list", length(env$bst_folds)) - for (i in seq_along(env$bst_folds)) { - dmp <- xgb.dump(env$bst_folds[[i]]$bst) - cf[[i]] <- as.numeric(grep('(booster|bias|weigh)', dmp, invert = TRUE, value = TRUE)) - if (sparse) cf[[i]] <- as(cf[[i]], "sparseVector") + is_booster <- inherits(model, "xgb.Booster") + if (is_booster) { + out <- .list2mat(env$coef_hist, env$sparse) + } else { + out <- lapply( + X = lapply( + X = seq_along(env$coef_hist[[1]]), + FUN = function(i) lapply(env$coef_hist, "[[", i) + ), + FUN = .list2mat, + env$sparse + ) } + if (!is.null(prev_cb_res)) { + if (is_booster) { + out <- rbind(prev_cb_res, out) + } else { + # Note: this case should never be encountered, since training cannot + # be continued from the result of xgb.cv, but this code should in + # theory do the job if the situation were to be encountered. + out <- lapply( + out, + function(lst) { + lapply( + seq_along(lst), + function(i) rbind(prev_cb_res[[i]], lst[[i]]) + ) + } + ) + } + } + feature_names <- getinfo(data, "feature_name") + if (!NROW(feature_names)) { + feature_names <- paste0("V", seq(1L, ncol(data))) + } + expected_ncols <- length(feature_names) + 1 + if (is_booster) { + mat_ncols <- ncol(out) + } else { + mat_ncols <- ncol(out[[1L]]) + } + if (mat_ncols %% expected_ncols == 0) { + feature_names <- c("(Intercept)", feature_names) + n_rep <- mat_ncols / expected_ncols + if (n_rep > 1) { + feature_names <- unlist( + lapply( + seq(1, n_rep), + function(cl) paste(feature_names, cl - 1, sep = ":") + ) + ) + } + if (is_booster) { + colnames(out) <- feature_names + } else { + out <- lapply( + out, + function(mat) { + colnames(mat) <- feature_names + return(mat) + } + ) + } + } + return(out) } - cf - } - - callback <- function(env = parent.frame(), finalize = FALSE) { - if (is.null(coefs)) init(env) - if (finalize) return(finalizer(env)) - cf <- extract.coef(env) - coefs <<- c(coefs, list(cf)) - } - - attr(callback, 'call') <- match.call() - attr(callback, 'name') <- 'cb.gblinear.history' - callback + ) } #' @title Extract gblinear coefficients history. #' @description A helper function to extract the matrix of linear coefficients' history -#' from a gblinear model created while using the \code{cb.gblinear.history()} -#' callback. +#' from a gblinear model created while using the \link{xgb.cb.gblinear.history} +#' callback (which must be added manually as by default it's not used). #' @details Note that this is an R-specific function that relies on R attributes that #' are not saved when using xgboost's own serialization functions like \link{xgb.load} #' or \link{xgb.load.raw}. #' -#' In order for a serialized model to be accepted by tgis function, one must use R +#' In order for a serialized model to be accepted by this function, one must use R #' serializers such as \link{saveRDS}. #' @param model either an \code{xgb.Booster} or a result of \code{xgb.cv()}, trained -#' using the \code{cb.gblinear.history()} callback, but \bold{not} a booster +#' using the \link{xgb.cb.gblinear.history} callback, but \bold{not} a booster #' loaded from \link{xgb.load} or \link{xgb.load.raw}. #' @param class_index zero-based class index to extract the coefficients for only that #' specific class in a multinomial multiclass model. When it is NULL, all the #' coefficients are returned. Has no effect in non-multiclass models. #' #' @return -#' For an \code{xgb.train} result, a matrix (either dense or sparse) with the columns -#' corresponding to iteration's coefficients (in the order as \code{xgb.dump()} would -#' return) and the rows corresponding to boosting iterations. +#' For an \link{xgb.train} result, a matrix (either dense or sparse) with the columns +#' corresponding to iteration's coefficients and the rows corresponding to boosting iterations. #' -#' For an \code{xgb.cv} result, a list of such matrices is returned with the elements +#' For an \link{xgb.cv} result, a list of such matrices is returned with the elements #' corresponding to CV folds. #' +#' When there is more than one coefficient per feature (e.g. multi-class classification) +#' and `class_index` is not provided, +#' the result will be reshaped into a vector where coefficients are arranged first by features and +#' then by class (e.g. first 1 through N coefficients will be for the first class, then +#' coefficients N+1 through 2N for the second class, and so on). +#' @seealso \link{xgb.cb.gblinear.history}, \link{coef.xgb.Booster}. #' @export xgb.gblinear.history <- function(model, class_index = NULL) { @@ -721,14 +1136,14 @@ xgb.gblinear.history <- function(model, class_index = NULL) { stop("model must be an object of either xgb.Booster or xgb.cv.synchronous class") is_cv <- inherits(model, "xgb.cv.synchronous") - if (is_cv) { - callbacks <- model$callbacks + if (!is_cv) { + coef_path <- getElement(attributes(model), "gblinear_history") } else { - callbacks <- attributes(model)$callbacks + coef_path <- getElement(model, "gblinear_history") + } + if (is.null(coef_path)) { + stop("model must be trained while using the xgb.cb.gblinear.history() callback") } - - if (is.null(callbacks) || is.null(callbacks$cb.gblinear.history)) - stop("model must be trained while using the cb.gblinear.history() callback") if (!is_cv) { num_class <- xgb.num_class(model) @@ -748,105 +1163,82 @@ xgb.gblinear.history <- function(model, class_index = NULL) { (class_index[1] < 0 || class_index[1] >= num_class)) stop("class_index has to be within [0,", num_class - 1, "]") - coef_path <- environment(callbacks$cb.gblinear.history)[["coefs"]] if (!is.null(class_index) && num_class > 1) { + seq_take <- seq(1 + class_index * (num_feat + 1), (class_index + 1) * (num_feat + 1)) coef_path <- if (is.list(coef_path)) { - lapply(coef_path, - function(x) x[, seq(1 + class_index, by = num_class, length.out = num_feat)]) + lapply(coef_path, function(x) x[, seq_take]) } else { - coef_path <- coef_path[, seq(1 + class_index, by = num_class, length.out = num_feat)] + coef_path <- coef_path[, seq_take] } } - coef_path + return(coef_path) } +.callbacks.only.train <- "save_model" +.callbacks.only.cv <- "cv_predict" -# -# Internal utility functions for callbacks ------------------------------------ -# - -# Format the evaluation metric string -.format_eval_string <- function(iter, eval_res, eval_err = NULL) { - if (length(eval_res) == 0) - stop('no evaluation results') - enames <- names(eval_res) - if (is.null(enames)) - stop('evaluation results must have names') - iter <- sprintf('[%d]\t', iter) - if (!is.null(eval_err)) { - if (length(eval_res) != length(eval_err)) - stop('eval_res & eval_err lengths mismatch') - # Note: UTF-8 code for plus/minus sign is U+00B1 - res <- paste0(sprintf("%s:%f\U00B1%f", enames, eval_res, eval_err), collapse = '\t') - } else { - res <- paste0(sprintf("%s:%f", enames, eval_res), collapse = '\t') +.process.callbacks <- function(callbacks, is_cv) { + if (inherits(callbacks, "xgb.Callback")) { + callbacks <- list(callbacks) } - return(paste0(iter, res)) -} - -# Extract callback names from the list of callbacks -callback.names <- function(cb_list) { - unlist(lapply(cb_list, function(x) attr(x, 'name'))) -} - -# Extract callback calls from the list of callbacks -callback.calls <- function(cb_list) { - unlist(lapply(cb_list, function(x) attr(x, 'call'))) -} - -# Add a callback cb to the list and make sure that -# cb.early.stop and cb.cv.predict are at the end of the list -# with cb.cv.predict being the last (when present) -add.cb <- function(cb_list, cb) { - cb_list <- c(cb_list, cb) - names(cb_list) <- callback.names(cb_list) - if ('cb.early.stop' %in% names(cb_list)) { - cb_list <- c(cb_list, cb_list['cb.early.stop']) - # this removes only the first one - cb_list['cb.early.stop'] <- NULL + if (!is.list(callbacks)) { + stop("'callbacks' must be a list.") } - if ('cb.cv.predict' %in% names(cb_list)) { - cb_list <- c(cb_list, cb_list['cb.cv.predict']) - cb_list['cb.cv.predict'] <- NULL + cb_names <- character() + if (length(callbacks)) { + is_callback <- sapply(callbacks, inherits, "xgb.Callback") + if (!all(is_callback)) { + stop("Entries in 'callbacks' must be 'xgb.Callback' objects.") + } + cb_names <- sapply(callbacks, function(cb) cb$cb_name) + if (length(cb_names) != length(callbacks)) { + stop("Passed invalid callback(s).") + } + if (anyDuplicated(cb_names) > 0) { + stop("Callbacks must have unique names.") + } + if (is_cv) { + if (any(.callbacks.only.train %in% cb_names)) { + stop( + "Passed callback(s) not supported for 'xgb.cv': ", + paste(intersect(.callbacks.only.train, cb_names), collapse = ", ") + ) + } + } else { + if (any(.callbacks.only.cv %in% cb_names)) { + stop( + "Passed callback(s) not supported for 'xgb.train': ", + paste(intersect(.callbacks.only.cv, cb_names), collapse = ", ") + ) + } + } + # Early stopping callback needs to be executed before the others + if ("early_stop" %in% cb_names) { + mask <- cb_names == "early_stop" + callbacks <- c(list(callbacks[[which(mask)]]), callbacks[!mask]) + } } - cb_list + return(list(callbacks = callbacks, cb_names = cb_names)) } -# Sort callbacks list into categories -categorize.callbacks <- function(cb_list) { - list( - pre_iter = Filter(function(x) { - pre <- attr(x, 'is_pre_iteration') - !is.null(pre) && pre - }, cb_list), - post_iter = Filter(function(x) { - pre <- attr(x, 'is_pre_iteration') - is.null(pre) || !pre - }, cb_list), - finalize = Filter(function(x) { - 'finalize' %in% names(formals(x)) - }, cb_list) - ) +# Note: don't try to use functions like 'append', as they will +# merge the elements of the different callbacks into a single list. +add.callback <- function(callbacks, cb, as_first_elt = FALSE) { + if (!as_first_elt) { + callbacks[[length(callbacks) + 1]] <- cb + return(callbacks) + } else { + if (!length(callbacks)) { + return(list(cb)) + } + new_cb <- vector("list", length(callbacks) + 1) + new_cb[[1]] <- cb + new_cb[seq(2, length(new_cb))] <- callbacks + return(new_cb) + } } -# Check whether all callback functions with names given by 'query_names' are present in the 'cb_list'. -has.callbacks <- function(cb_list, query_names) { - if (length(cb_list) < length(query_names)) - return(FALSE) - if (!is.list(cb_list) || - any(sapply(cb_list, class) != 'function')) { - stop('`cb_list` must be a list of callback functions') - } - cb_names <- callback.names(cb_list) - if (!is.character(cb_names) || - length(cb_names) != length(cb_list) || - any(cb_names == "")) { - stop('All callbacks in the `cb_list` must have a non-empty `name` attribute') - } - if (!is.character(query_names) || - length(query_names) == 0 || - any(query_names == "")) { - stop('query_names must be a non-empty vector of non-empty character names') - } - return(all(query_names %in% cb_names)) +has.callbacks <- function(callbacks, cb_name) { + cb_names <- sapply(callbacks, function(cb) cb$name) + return(cb_name %in% cb_names) } diff --git a/R-package/R/utils.R b/R-package/R/utils.R index e8ae787fc722..7b6a20f704dd 100644 --- a/R-package/R/utils.R +++ b/R-package/R/utils.R @@ -26,6 +26,11 @@ NVL <- function(x, val) { 'multi:softprob', 'rank:pairwise', 'rank:ndcg', 'rank:map')) } +.RANKING_OBJECTIVES <- function() { + return(c('binary:logistic', 'binary:logitraw', 'binary:hinge', 'multi:softmax', + 'multi:softprob')) +} + # # Low-level functions for boosting -------------------------------------------- @@ -142,7 +147,7 @@ check.custom.eval <- function(env = parent.frame()) { if (!is.null(env$feval) && is.null(env$maximize) && ( !is.null(env$early_stopping_rounds) || - has.callbacks(env$callbacks, 'cb.early.stop'))) + has.callbacks(env$callbacks, "early_stop"))) stop("Please set 'maximize' to indicate whether the evaluation metric needs to be maximized or not") } @@ -193,20 +198,20 @@ xgb.iter.update <- function(bst, dtrain, iter, obj) { # Evaluate one iteration. # Returns a named vector of evaluation metrics # with the names in a 'datasetname-metricname' format. -xgb.iter.eval <- function(bst, watchlist, iter, feval) { +xgb.iter.eval <- function(bst, evals, iter, feval) { handle <- xgb.get.handle(bst) - if (length(watchlist) == 0) + if (length(evals) == 0) return(NULL) - evnames <- names(watchlist) + evnames <- names(evals) if (is.null(feval)) { - msg <- .Call(XGBoosterEvalOneIter_R, handle, as.integer(iter), watchlist, as.list(evnames)) + msg <- .Call(XGBoosterEvalOneIter_R, handle, as.integer(iter), evals, as.list(evnames)) mat <- matrix(strsplit(msg, '\\s+|:')[[1]][-1], nrow = 2) res <- structure(as.numeric(mat[2, ]), names = mat[1, ]) } else { - res <- sapply(seq_along(watchlist), function(j) { - w <- watchlist[[j]] + res <- sapply(seq_along(evals), function(j) { + w <- evals[[j]] ## predict using all trees preds <- predict(bst, w, outputmargin = TRUE, iterationrange = "all") eval_res <- feval(preds, w) @@ -235,33 +240,43 @@ convert.labels <- function(labels, objective_name) { } # Generates random (stratified if needed) CV folds -generate.cv.folds <- function(nfold, nrows, stratified, label, params) { +generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) { + if (NROW(group)) { + if (stratified) { + warning( + paste0( + "Stratified splitting is not supported when using 'group' attribute.", + " Will use unstratified splitting." + ) + ) + } + return(generate.group.folds(nfold, group)) + } + objective <- params$objective + if (!is.character(objective)) { + warning("Will use unstratified splitting (custom objective used)") + stratified <- FALSE + } + # cannot stratify if label is NULL + if (stratified && is.null(label)) { + warning("Will use unstratified splitting (no 'labels' available)") + stratified <- FALSE + } # cannot do it for rank - objective <- params$objective if (is.character(objective) && strtrim(objective, 5) == 'rank:') { - stop("\n\tAutomatic generation of CV-folds is not implemented for ranking!\n", + stop("\n\tAutomatic generation of CV-folds is not implemented for ranking without 'group' field!\n", "\tConsider providing pre-computed CV-folds through the 'folds=' parameter.\n") } # shuffle rnd_idx <- sample.int(nrows) - if (stratified && - length(label) == length(rnd_idx)) { + if (stratified && length(label) == length(rnd_idx)) { y <- label[rnd_idx] - # WARNING: some heuristic logic is employed to identify classification setting! # - For classification, need to convert y labels to factor before making the folds, # and then do stratification by factor levels. # - For regression, leave y numeric and do stratification by quantiles. if (is.character(objective)) { - y <- convert.labels(y, params$objective) - } else { - # If no 'objective' given in params, it means that user either wants to - # use the default 'reg:squarederror' objective or has provided a custom - # obj function. Here, assume classification setting when y has 5 or less - # unique values: - if (length(unique(y)) <= 5) { - y <- factor(y) - } + y <- convert.labels(y, objective) } folds <- xgb.createFolds(y = y, k = nfold) } else { @@ -277,6 +292,29 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) { return(folds) } +generate.group.folds <- function(nfold, group) { + ngroups <- length(group) - 1 + if (ngroups < nfold) { + stop("DMatrix has fewer groups than folds.") + } + seq_groups <- seq_len(ngroups) + indices <- lapply(seq_groups, function(gr) seq(group[gr] + 1, group[gr + 1])) + assignments <- base::split(seq_groups, as.integer(seq_groups %% nfold)) + assignments <- unname(assignments) + + out <- vector("list", nfold) + randomized_groups <- sample(ngroups) + for (idx in seq_len(nfold)) { + groups_idx_test <- randomized_groups[assignments[[idx]]] + groups_test <- indices[groups_idx_test] + idx_test <- unlist(groups_test) + attributes(idx_test)$group_test <- lengths(groups_test) + attributes(idx_test)$group_train <- lengths(indices[-groups_idx_test]) + out[[idx]] <- idx_test + } + return(out) +} + # Creates CV folds stratified by the values of y. # It was borrowed from caret::createFolds and simplified # by always returning an unnamed list of fold indices. @@ -454,7 +492,8 @@ depr_par_lut <- matrix(c( 'plot.height', 'plot_height', 'plot.width', 'plot_width', 'n_first_tree', 'trees', - 'dummy', 'DUMMY' + 'dummy', 'DUMMY', + 'watchlist', 'evals' ), ncol = 2, byrow = TRUE) colnames(depr_par_lut) <- c('old', 'new') diff --git a/R-package/R/xgb.Booster.R b/R-package/R/xgb.Booster.R index 8a5d66198834..77d75fa9c2a5 100644 --- a/R-package/R/xgb.Booster.R +++ b/R-package/R/xgb.Booster.R @@ -1071,6 +1071,10 @@ xgb.best_iteration <- function(bst) { #' coef(model) #' @export coef.xgb.Booster <- function(object, ...) { + return(.internal.coef.xgb.Booster(object, add_names = TRUE)) +} + +.internal.coef.xgb.Booster <- function(object, add_names = TRUE) { booster_type <- xgb.booster_type(object) if (booster_type != "gblinear") { stop("Coefficients are not defined for Booster type ", booster_type) @@ -1089,21 +1093,27 @@ coef.xgb.Booster <- function(object, ...) { intercepts <- weights[seq(sep + 1, length(weights))] intercepts <- intercepts + as.numeric(base_score) - feature_names <- xgb.feature_names(object) - if (!NROW(feature_names)) { - # This mimics the default naming in R which names columns as "V1..N" - # when names are needed but not available - feature_names <- paste0("V", seq(1L, num_feature)) + if (add_names) { + feature_names <- xgb.feature_names(object) + if (!NROW(feature_names)) { + # This mimics the default naming in R which names columns as "V1..N" + # when names are needed but not available + feature_names <- paste0("V", seq(1L, num_feature)) + } + feature_names <- c("(Intercept)", feature_names) } - feature_names <- c("(Intercept)", feature_names) if (n_cols == 1L) { out <- c(intercepts, coefs) - names(out) <- feature_names + if (add_names) { + names(out) <- feature_names + } } else { coefs <- matrix(coefs, nrow = num_feature, byrow = TRUE) dim(intercepts) <- c(1L, n_cols) out <- rbind(intercepts, coefs) - row.names(out) <- feature_names + if (add_names) { + row.names(out) <- feature_names + } # TODO: if a class names attributes is added, # should use those names here. } @@ -1255,12 +1265,9 @@ print.xgb.Booster <- function(x, ...) { cat(" ", paste(attr_names, collapse = ", "), "\n") } - if (!is.null(R_attrs$callbacks) && length(R_attrs$callbacks) > 0) { - cat('callbacks:\n') - lapply(callback.calls(R_attrs$callbacks), function(x) { - cat(' ') - print(x) - }) + additional_attr <- setdiff(names(R_attrs), .reserved_cb_names) + if (NROW(additional_attr)) { + cat("callbacks:\n ", paste(additional_attr, collapse = ", "), "\n") } if (!is.null(R_attrs$evaluation_log)) { diff --git a/R-package/R/xgb.DMatrix.R b/R-package/R/xgb.DMatrix.R index edbc267c1067..15f6faed0ba0 100644 --- a/R-package/R/xgb.DMatrix.R +++ b/R-package/R/xgb.DMatrix.R @@ -1259,8 +1259,11 @@ xgb.get.DMatrix.data <- function(dmat) { #' Get a new DMatrix containing the specified rows of #' original xgb.DMatrix object #' -#' @param object Object of class "xgb.DMatrix" -#' @param idxset a integer vector of indices of rows needed +#' @param object Object of class "xgb.DMatrix". +#' @param idxset An integer vector of indices of rows needed (base-1 indexing). +#' @param allow_groups Whether to allow slicing an `xgb.DMatrix` with `group` (or +#' equivalently `qid`) field. Note that in such case, the result will not have +#' the groups anymore - they need to be set manually through `setinfo`. #' @param colset currently not used (columns subsetting is not available) #' #' @examples @@ -1275,11 +1278,11 @@ xgb.get.DMatrix.data <- function(dmat) { #' #' @rdname xgb.slice.DMatrix #' @export -xgb.slice.DMatrix <- function(object, idxset) { +xgb.slice.DMatrix <- function(object, idxset, allow_groups = FALSE) { if (!inherits(object, "xgb.DMatrix")) { stop("object must be xgb.DMatrix") } - ret <- .Call(XGDMatrixSliceDMatrix_R, object, idxset) + ret <- .Call(XGDMatrixSliceDMatrix_R, object, idxset, allow_groups) attr_list <- attributes(object) nr <- nrow(object) @@ -1296,7 +1299,15 @@ xgb.slice.DMatrix <- function(object, idxset) { } } } - return(structure(ret, class = "xgb.DMatrix")) + + out <- structure(ret, class = "xgb.DMatrix") + parent_fields <- as.list(attributes(object)$fields) + if (NROW(parent_fields)) { + child_fields <- parent_fields[!(names(parent_fields) %in% c("group", "qid"))] + child_fields <- as.environment(child_fields) + attributes(out)$fields <- child_fields + } + return(out) } #' @rdname xgb.slice.DMatrix @@ -1340,11 +1351,11 @@ print.xgb.DMatrix <- function(x, verbose = FALSE, ...) { } cat(class_print, ' dim:', nrow(x), 'x', ncol(x), ' info: ') - infos <- character(0) - if (xgb.DMatrix.hasinfo(x, 'label')) infos <- 'label' - if (xgb.DMatrix.hasinfo(x, 'weight')) infos <- c(infos, 'weight') - if (xgb.DMatrix.hasinfo(x, 'base_margin')) infos <- c(infos, 'base_margin') - if (length(infos) == 0) infos <- 'NA' + infos <- names(attributes(x)$fields) + infos <- infos[infos != "feature_name"] + if (!NROW(infos)) infos <- "NA" + infos <- infos[order(infos)] + infos <- paste(infos, collapse = ", ") cat(infos) cnames <- colnames(x) cat(' colnames:') diff --git a/R-package/R/xgb.create.features.R b/R-package/R/xgb.create.features.R index baef3bb03e28..27f8a0975ae7 100644 --- a/R-package/R/xgb.create.features.R +++ b/R-package/R/xgb.create.features.R @@ -71,7 +71,6 @@ #' new.dtest <- xgb.DMatrix( #' data = new.features.test, label = agaricus.test$label, nthread = 2 #' ) -#' watchlist <- list(train = new.dtrain) #' bst <- xgb.train(params = param, data = new.dtrain, nrounds = nrounds, nthread = 2) #' #' # Model accuracy with new features diff --git a/R-package/R/xgb.cv.R b/R-package/R/xgb.cv.R index 29bddb57f3e2..880fd56974bc 100644 --- a/R-package/R/xgb.cv.R +++ b/R-package/R/xgb.cv.R @@ -1,6 +1,6 @@ #' Cross Validation #' -#' The cross validation function of xgboost +#' The cross validation function of xgboost. #' #' @param params the list of parameters. The complete list of parameters is #' available in the \href{http://xgboost.readthedocs.io/en/latest/parameter.html}{online documentation}. Below @@ -19,15 +19,19 @@ #' #' See \code{\link{xgb.train}} for further details. #' See also demo/ for walkthrough example in R. -#' @param data takes an \code{xgb.DMatrix}, \code{matrix}, or \code{dgCMatrix} as the input. +#' +#' Note that, while `params` accepts a `seed` entry and will use such parameter for model training if +#' supplied, this seed is not used for creation of train-test splits, which instead rely on R's own RNG +#' system - thus, for reproducible results, one needs to call the `set.seed` function beforehand. +#' @param data An `xgb.DMatrix` object, with corresponding fields like `label` or bounds as required +#' for model training by the objective. +#' +#' Note that only the basic `xgb.DMatrix` class is supported - variants such as `xgb.QuantileDMatrix` +#' or `xgb.ExternalDMatrix` are not supported here. #' @param nrounds the max number of iterations #' @param nfold the original dataset is randomly partitioned into \code{nfold} equal size subsamples. -#' @param label vector of response values. Should be provided only when data is an R-matrix. -#' @param missing is only used when input is a dense matrix. By default is set to NA, which means -#' that NA values should be considered as 'missing' by the algorithm. -#' Sometimes, 0 or other extreme value might be used to represent missing values. #' @param prediction A logical value indicating whether to return the test fold predictions -#' from each CV model. This parameter engages the \code{\link{cb.cv.predict}} callback. +#' from each CV model. This parameter engages the \code{\link{xgb.cb.cv.predict}} callback. #' @param showsd \code{boolean}, whether to show standard deviation of cross validation #' @param metrics, list of evaluation metrics to be used in cross validation, #' when it is not specified, the evaluation metric is chosen according to objective function. @@ -47,27 +51,44 @@ #' @param feval customized evaluation function. Returns #' \code{list(metric='metric-name', value='metric-value')} with given #' prediction and dtrain. -#' @param stratified a \code{boolean} indicating whether sampling of folds should be stratified -#' by the values of outcome labels. +#' @param stratified A \code{boolean} indicating whether sampling of folds should be stratified +#' by the values of outcome labels. For real-valued labels in regression objectives, +#' stratification will be done by discretizing the labels into up to 5 buckets beforehand. +#' +#' If passing "auto", will be set to `TRUE` if the objective in `params` is a classification +#' objective (from XGBoost's built-in objectives, doesn't apply to custom ones), and to +#' `FALSE` otherwise. +#' +#' This parameter is ignored when `data` has a `group` field - in such case, the splitting +#' will be based on whole groups (note that this might make the folds have different sizes). +#' +#' Value `TRUE` here is \bold{not} supported for custom objectives. #' @param folds \code{list} provides a possibility to use a list of pre-defined CV folds #' (each element must be a vector of test fold's indices). When folds are supplied, #' the \code{nfold} and \code{stratified} parameters are ignored. +#' +#' If `data` has a `group` field and the objective requires this field, each fold (list element) +#' must additionally have two attributes (retrievable through \link{attributes}) named `group_test` +#' and `group_train`, which should hold the `group` to assign through \link{setinfo.xgb.DMatrix} to +#' the resulting DMatrices. #' @param train_folds \code{list} list specifying which indicies to use for training. If \code{NULL} #' (the default) all indices not specified in \code{folds} will be used for training. +#' +#' This is not supported when `data` has `group` field. #' @param verbose \code{boolean}, print the statistics during the process #' @param print_every_n Print each n-th iteration evaluation messages when \code{verbose>0}. #' Default is 1 which means all messages are printed. This parameter is passed to the -#' \code{\link{cb.print.evaluation}} callback. +#' \code{\link{xgb.cb.print.evaluation}} callback. #' @param early_stopping_rounds If \code{NULL}, the early stopping function is not triggered. #' If set to an integer \code{k}, training with a validation set will stop if the performance #' doesn't improve for \code{k} rounds. -#' Setting this parameter engages the \code{\link{cb.early.stop}} callback. +#' Setting this parameter engages the \code{\link{xgb.cb.early.stop}} callback. #' @param maximize If \code{feval} and \code{early_stopping_rounds} are set, #' then this parameter must be set as well. #' When it is \code{TRUE}, it means the larger the evaluation score the better. -#' This parameter is passed to the \code{\link{cb.early.stop}} callback. +#' This parameter is passed to the \code{\link{xgb.cb.early.stop}} callback. #' @param callbacks a list of callback functions to perform various task during boosting. -#' See \code{\link{callbacks}}. Some of the callbacks are automatically created depending on the +#' See \code{\link{xgb.Callback}}. Some of the callbacks are automatically created depending on the #' parameters' values. User can provide either existing or their own callback methods in order #' to customize the training process. #' @param ... other parameters to pass to \code{params}. @@ -90,25 +111,25 @@ #' \itemize{ #' \item \code{call} a function call. #' \item \code{params} parameters that were passed to the xgboost library. Note that it does not -#' capture parameters changed by the \code{\link{cb.reset.parameters}} callback. -#' \item \code{callbacks} callback functions that were either automatically assigned or -#' explicitly passed. +#' capture parameters changed by the \code{\link{xgb.cb.reset.parameters}} callback. #' \item \code{evaluation_log} evaluation history stored as a \code{data.table} with the #' first column corresponding to iteration number and the rest corresponding to the #' CV-based evaluation means and standard deviations for the training and test CV-sets. -#' It is created by the \code{\link{cb.evaluation.log}} callback. +#' It is created by the \code{\link{xgb.cb.evaluation.log}} callback. #' \item \code{niter} number of boosting iterations. #' \item \code{nfeatures} number of features in training data. #' \item \code{folds} the list of CV folds' indices - either those passed through the \code{folds} #' parameter or randomly generated. #' \item \code{best_iteration} iteration number with the best evaluation metric value #' (only available with early stopping). -#' \item \code{pred} CV prediction values available when \code{prediction} is set. -#' It is either vector or matrix (see \code{\link{cb.cv.predict}}). -#' \item \code{models} a list of the CV folds' models. It is only available with the explicit -#' setting of the \code{cb.cv.predict(save_models = TRUE)} callback. #' } #' +#' Plus other potential elements that are the result of callbacks, such as a list `cv_predict` with +#' a sub-element `pred` when passing `prediction = TRUE`, which is added by the \link{xgb.cb.cv.predict} +#' callback (note that one can also pass it manually under `callbacks` with different settings, +#' such as saving also the models created during cross validation); or a list `early_stop` which +#' will contain elements such as `best_iteration` when using the early stopping callback (\link{xgb.cb.early.stop}). +#' #' @examples #' data(agaricus.train, package='xgboost') #' dtrain <- with(agaricus.train, xgb.DMatrix(data, label = label, nthread = 2)) @@ -118,13 +139,14 @@ #' print(cv, verbose=TRUE) #' #' @export -xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing = NA, +xgb.cv <- function(params = list(), data, nrounds, nfold, prediction = FALSE, showsd = TRUE, metrics = list(), - obj = NULL, feval = NULL, stratified = TRUE, folds = NULL, train_folds = NULL, + obj = NULL, feval = NULL, stratified = "auto", folds = NULL, train_folds = NULL, verbose = TRUE, print_every_n = 1L, early_stopping_rounds = NULL, maximize = NULL, callbacks = list(), ...) { check.deprecation(...) + stopifnot(inherits(data, "xgb.DMatrix")) if (inherits(data, "xgb.DMatrix") && .Call(XGCheckNullPtr_R, data)) { stop("'data' is an invalid 'xgb.DMatrix' object. Must be constructed again.") } @@ -137,16 +159,22 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing check.custom.obj() check.custom.eval() - # Check the labels - if ((inherits(data, 'xgb.DMatrix') && !xgb.DMatrix.hasinfo(data, 'label')) || - (!inherits(data, 'xgb.DMatrix') && is.null(label))) { - stop("Labels must be provided for CV either through xgb.DMatrix, or through 'label=' when 'data' is matrix") - } else if (inherits(data, 'xgb.DMatrix')) { - if (!is.null(label)) - warning("xgb.cv: label will be ignored, since data is of type xgb.DMatrix") - cv_label <- getinfo(data, 'label') - } else { - cv_label <- label + if (stratified == "auto") { + if (is.character(params$objective)) { + stratified <- ( + (params$objective %in% .CLASSIFICATION_OBJECTIVES()) + && !(params$objective %in% .RANKING_OBJECTIVES()) + ) + } else { + stratified <- FALSE + } + } + + # Check the labels and groups + cv_label <- getinfo(data, "label") + cv_group <- getinfo(data, "group") + if (!is.null(train_folds) && NROW(cv_group)) { + stop("'train_folds' is not supported for DMatrix object with 'group' field.") } # CV folds @@ -157,63 +185,64 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing } else { if (nfold <= 1) stop("'nfold' must be > 1") - folds <- generate.cv.folds(nfold, nrow(data), stratified, cv_label, params) + folds <- generate.cv.folds(nfold, nrow(data), stratified, cv_label, cv_group, params) } + # Callbacks + tmp <- .process.callbacks(callbacks, is_cv = TRUE) + callbacks <- tmp$callbacks + cb_names <- tmp$cb_names + rm(tmp) + + # Early stopping callback + if (!is.null(early_stopping_rounds) && !("early_stop" %in% cb_names)) { + callbacks <- add.callback( + callbacks, + xgb.cb.early.stop( + early_stopping_rounds, + maximize = maximize, + verbose = verbose + ), + as_first_elt = TRUE + ) + } # verbosity & evaluation printing callback: params <- c(params, list(silent = 1)) print_every_n <- max(as.integer(print_every_n), 1L) - if (!has.callbacks(callbacks, 'cb.print.evaluation') && verbose) { - callbacks <- add.cb(callbacks, cb.print.evaluation(print_every_n, showsd = showsd)) + if (verbose && !("print_evaluation" %in% cb_names)) { + callbacks <- add.callback(callbacks, xgb.cb.print.evaluation(print_every_n, showsd = showsd)) } # evaluation log callback: always is on in CV - evaluation_log <- list() - if (!has.callbacks(callbacks, 'cb.evaluation.log')) { - callbacks <- add.cb(callbacks, cb.evaluation.log()) - } - # Early stopping callback - stop_condition <- FALSE - if (!is.null(early_stopping_rounds) && - !has.callbacks(callbacks, 'cb.early.stop')) { - callbacks <- add.cb(callbacks, cb.early.stop(early_stopping_rounds, - maximize = maximize, verbose = verbose)) + if (!("evaluation_log" %in% cb_names)) { + callbacks <- add.callback(callbacks, xgb.cb.evaluation.log()) } # CV-predictions callback - if (prediction && - !has.callbacks(callbacks, 'cb.cv.predict')) { - callbacks <- add.cb(callbacks, cb.cv.predict(save_models = FALSE)) + if (prediction && !("cv_predict" %in% cb_names)) { + callbacks <- add.callback(callbacks, xgb.cb.cv.predict(save_models = FALSE)) } - # Sort the callbacks into categories - cb <- categorize.callbacks(callbacks) - # create the booster-folds # train_folds - dall <- xgb.get.DMatrix( - data = data, - label = label, - missing = missing, - weight = NULL, - nthread = params$nthread - ) + dall <- data bst_folds <- lapply(seq_along(folds), function(k) { - dtest <- xgb.slice.DMatrix(dall, folds[[k]]) + dtest <- xgb.slice.DMatrix(dall, folds[[k]], allow_groups = TRUE) # code originally contributed by @RolandASc on stackoverflow if (is.null(train_folds)) - dtrain <- xgb.slice.DMatrix(dall, unlist(folds[-k])) + dtrain <- xgb.slice.DMatrix(dall, unlist(folds[-k]), allow_groups = TRUE) else - dtrain <- xgb.slice.DMatrix(dall, train_folds[[k]]) + dtrain <- xgb.slice.DMatrix(dall, train_folds[[k]], allow_groups = TRUE) + if (!is.null(attributes(folds[[k]])$group_test)) { + setinfo(dtest, "group", attributes(folds[[k]])$group_test) + setinfo(dtrain, "group", attributes(folds[[k]])$group_train) + } bst <- xgb.Booster( params = params, cachelist = list(dtrain, dtest), modelfile = NULL ) bst <- bst$bst - list(dtrain = dtrain, bst = bst, watchlist = list(train = dtrain, test = dtest), index = folds[[k]]) + list(dtrain = dtrain, bst = bst, evals = list(train = dtrain, test = dtest), index = folds[[k]]) }) - rm(dall) - # a "basket" to collect some results from callbacks - basket <- list() # extract parameters that can affect the relationship b/w #trees and #iterations num_class <- max(as.numeric(NVL(params[['num_class']], 1)), 1) # nolint @@ -222,10 +251,25 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing begin_iteration <- 1 end_iteration <- nrounds + .execute.cb.before.training( + callbacks, + bst_folds, + dall, + NULL, + begin_iteration, + end_iteration + ) + # synchronous CV boosting: run CV folds' models within each iteration for (iteration in begin_iteration:end_iteration) { - for (f in cb$pre_iter) f() + .execute.cb.before.iter( + callbacks, + bst_folds, + dall, + NULL, + iteration + ) msg <- lapply(bst_folds, function(fd) { xgb.iter.update( @@ -236,33 +280,42 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing ) xgb.iter.eval( bst = fd$bst, - watchlist = fd$watchlist, + evals = fd$evals, iter = iteration - 1, feval = feval ) }) msg <- simplify2array(msg) - # Note: these variables might look unused here, but they are used in the callbacks - bst_evaluation <- rowMeans(msg) # nolint - bst_evaluation_err <- apply(msg, 1, sd) # nolint - for (f in cb$post_iter) f() + should_stop <- .execute.cb.after.iter( + callbacks, + bst_folds, + dall, + NULL, + iteration, + msg + ) - if (stop_condition) break + if (should_stop) break } - for (f in cb$finalize) f(finalize = TRUE) + cb_outputs <- .execute.cb.after.training( + callbacks, + bst_folds, + dall, + NULL, + iteration, + msg + ) # the CV result ret <- list( call = match.call(), params = params, - callbacks = callbacks, - evaluation_log = evaluation_log, - niter = end_iteration, - nfeatures = ncol(data), + niter = iteration, + nfeatures = ncol(dall), folds = folds ) - ret <- c(ret, basket) + ret <- c(ret, cb_outputs) class(ret) <- 'xgb.cv.synchronous' return(invisible(ret)) @@ -285,8 +338,8 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing #' @examples #' data(agaricus.train, package='xgboost') #' train <- agaricus.train -#' cv <- xgb.cv(data = train$data, label = train$label, nfold = 5, max_depth = 2, -#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic") +#' cv <- xgb.cv(data = xgb.DMatrix(train$data, label = train$label), nfold = 5, max_depth = 2, +#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic") #' print(cv) #' print(cv, verbose=TRUE) #' @@ -308,23 +361,16 @@ print.xgb.cv.synchronous <- function(x, verbose = FALSE, ...) { paste0('"', unlist(x$params), '"'), sep = ' = ', collapse = ', '), '\n', sep = '') } - if (!is.null(x$callbacks) && length(x$callbacks) > 0) { - cat('callbacks:\n') - lapply(callback.calls(x$callbacks), function(x) { - cat(' ') - print(x) - }) - } for (n in c('niter', 'best_iteration')) { - if (is.null(x[[n]])) + if (is.null(x$early_stop[[n]])) next - cat(n, ': ', x[[n]], '\n', sep = '') + cat(n, ': ', x$early_stop[[n]], '\n', sep = '') } - if (!is.null(x$pred)) { + if (!is.null(x$cv_predict$pred)) { cat('pred:\n') - str(x$pred) + str(x$cv_predict$pred) } } @@ -332,9 +378,9 @@ print.xgb.cv.synchronous <- function(x, verbose = FALSE, ...) { cat('evaluation_log:\n') print(x$evaluation_log, row.names = FALSE, ...) - if (!is.null(x$best_iteration)) { + if (!is.null(x$early_stop$best_iteration)) { cat('Best iteration:\n') - print(x$evaluation_log[x$best_iteration], row.names = FALSE, ...) + print(x$evaluation_log[x$early_stop$best_iteration], row.names = FALSE, ...) } invisible(x) } diff --git a/R-package/R/xgb.load.R b/R-package/R/xgb.load.R index 4985f74b56c6..d5b192bcb6fa 100644 --- a/R-package/R/xgb.load.R +++ b/R-package/R/xgb.load.R @@ -6,7 +6,7 @@ #' #' @details #' The input file is expected to contain a model saved in an xgboost model format -#' using either \code{\link{xgb.save}} or \code{\link{cb.save.model}} in R, or using some +#' using either \code{\link{xgb.save}} or \code{\link{xgb.cb.save.model}} in R, or using some #' appropriate methods from other xgboost interfaces. E.g., a model trained in Python and #' saved from there in xgboost format, could be loaded from R. #' diff --git a/R-package/R/xgb.train.R b/R-package/R/xgb.train.R index 44cde2e7a843..4cea088e0e45 100644 --- a/R-package/R/xgb.train.R +++ b/R-package/R/xgb.train.R @@ -114,13 +114,13 @@ #' @param data training dataset. \code{xgb.train} accepts only an \code{xgb.DMatrix} as the input. #' \code{xgboost}, in addition, also accepts \code{matrix}, \code{dgCMatrix}, or name of a local data file. #' @param nrounds max number of boosting iterations. -#' @param watchlist named list of xgb.DMatrix datasets to use for evaluating model performance. +#' @param evals Named list of `xgb.DMatrix` datasets to use for evaluating model performance. #' Metrics specified in either \code{eval_metric} or \code{feval} will be computed for each #' of these datasets during each boosting iteration, and stored in the end as a field named #' \code{evaluation_log} in the resulting object. When either \code{verbose>=1} or -#' \code{\link{cb.print.evaluation}} callback is engaged, the performance results are continuously +#' \code{\link{xgb.cb.print.evaluation}} callback is engaged, the performance results are continuously #' printed out during the training. -#' E.g., specifying \code{watchlist=list(validation1=mat1, validation2=mat2)} allows to track +#' E.g., specifying \code{evals=list(validation1=mat1, validation2=mat2)} allows to track #' the performance of each round's model on mat1 and mat2. #' @param obj customized objective function. Returns gradient and second order #' gradient with given prediction and dtrain. @@ -130,31 +130,32 @@ #' @param verbose If 0, xgboost will stay silent. If 1, it will print information about performance. #' If 2, some additional information will be printed out. #' Note that setting \code{verbose > 0} automatically engages the -#' \code{cb.print.evaluation(period=1)} callback function. +#' \code{xgb.cb.print.evaluation(period=1)} callback function. #' @param print_every_n Print each n-th iteration evaluation messages when \code{verbose>0}. #' Default is 1 which means all messages are printed. This parameter is passed to the -#' \code{\link{cb.print.evaluation}} callback. +#' \code{\link{xgb.cb.print.evaluation}} callback. #' @param early_stopping_rounds If \code{NULL}, the early stopping function is not triggered. #' If set to an integer \code{k}, training with a validation set will stop if the performance #' doesn't improve for \code{k} rounds. -#' Setting this parameter engages the \code{\link{cb.early.stop}} callback. +#' Setting this parameter engages the \code{\link{xgb.cb.early.stop}} callback. #' @param maximize If \code{feval} and \code{early_stopping_rounds} are set, #' then this parameter must be set as well. #' When it is \code{TRUE}, it means the larger the evaluation score the better. -#' This parameter is passed to the \code{\link{cb.early.stop}} callback. +#' This parameter is passed to the \code{\link{xgb.cb.early.stop}} callback. #' @param save_period when it is non-NULL, model is saved to disk after every \code{save_period} rounds, -#' 0 means save at the end. The saving is handled by the \code{\link{cb.save.model}} callback. +#' 0 means save at the end. The saving is handled by the \code{\link{xgb.cb.save.model}} callback. #' @param save_name the name or path for periodically saved model file. #' @param xgb_model a previously built model to continue the training from. #' Could be either an object of class \code{xgb.Booster}, or its raw data, or the name of a #' file with a previously saved model. #' @param callbacks a list of callback functions to perform various task during boosting. -#' See \code{\link{callbacks}}. Some of the callbacks are automatically created depending on the +#' See \code{\link{xgb.Callback}}. Some of the callbacks are automatically created depending on the #' parameters' values. User can provide either existing or their own callback methods in order #' to customize the training process. #' -#' Note that some callbacks might try to set an evaluation log - be aware that these evaluation logs -#' are kept as R attributes, and thus do not get saved when using non-R serializaters like +#' Note that some callbacks might try to leave attributes in the resulting model object, +#' such as an evaluation log (a `data.table` object) - be aware that these objects are kept +#' as R attributes, and thus do not get saved when using XGBoost's own serializaters like #' \link{xgb.save} (but are kept when using R serializers like \link{saveRDS}). #' @param ... other parameters to pass to \code{params}. #' @param label vector of response values. Should not be provided when data is @@ -170,7 +171,7 @@ #' @details #' These are the training functions for \code{xgboost}. #' -#' The \code{xgb.train} interface supports advanced features such as \code{watchlist}, +#' The \code{xgb.train} interface supports advanced features such as \code{evals}, #' customized objective and evaluation metric functions, therefore it is more flexible #' than the \code{xgboost} interface. #' @@ -206,18 +207,19 @@ #' #' The following callbacks are automatically created when certain parameters are set: #' \itemize{ -#' \item \code{cb.print.evaluation} is turned on when \code{verbose > 0}; +#' \item \code{xgb.cb.print.evaluation} is turned on when \code{verbose > 0}; #' and the \code{print_every_n} parameter is passed to it. -#' \item \code{cb.evaluation.log} is on when \code{watchlist} is present. -#' \item \code{cb.early.stop}: when \code{early_stopping_rounds} is set. -#' \item \code{cb.save.model}: when \code{save_period > 0} is set. +#' \item \code{xgb.cb.evaluation.log} is on when \code{evals} is present. +#' \item \code{xgb.cb.early.stop}: when \code{early_stopping_rounds} is set. +#' \item \code{xgb.cb.save.model}: when \code{save_period > 0} is set. #' } #' #' Note that objects of type `xgb.Booster` as returned by this function behave a bit differently #' from typical R objects (it's an 'altrep' list class), and it makes a separation between #' internal booster attributes (restricted to jsonifyable data), accessed through \link{xgb.attr} #' and shared between interfaces through serialization functions like \link{xgb.save}; and -#' R-specific attributes, accessed through \link{attributes} and \link{attr}, which are otherwise +#' R-specific attributes (typically the result from a callback), accessed through \link{attributes} +#' and \link{attr}, which are otherwise #' only used in the R interface, only kept when using R's serializers like \link{saveRDS}, and #' not anyhow used by functions like \link{predict.xgb.Booster}. #' @@ -229,7 +231,7 @@ #' effect elsewhere. #' #' @seealso -#' \code{\link{callbacks}}, +#' \code{\link{xgb.Callback}}, #' \code{\link{predict.xgb.Booster}}, #' \code{\link{xgb.cv}} #' @@ -252,12 +254,12 @@ #' dtest <- with( #' agaricus.test, xgb.DMatrix(data, label = label, nthread = nthread) #' ) -#' watchlist <- list(train = dtrain, eval = dtest) +#' evals <- list(train = dtrain, eval = dtest) #' #' ## A simple xgb.train example: #' param <- list(max_depth = 2, eta = 1, nthread = nthread, #' objective = "binary:logistic", eval_metric = "auc") -#' bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0) +#' bst <- xgb.train(param, dtrain, nrounds = 2, evals = evals, verbose = 0) #' #' ## An xgb.train example where custom objective and evaluation metric are #' ## used: @@ -278,15 +280,15 @@ #' # as 'objective' and 'eval_metric' parameters in the params list: #' param <- list(max_depth = 2, eta = 1, nthread = nthread, #' objective = logregobj, eval_metric = evalerror) -#' bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0) +#' bst <- xgb.train(param, dtrain, nrounds = 2, evals = evals, verbose = 0) #' #' # or through the ... arguments: #' param <- list(max_depth = 2, eta = 1, nthread = nthread) -#' bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0, +#' bst <- xgb.train(param, dtrain, nrounds = 2, evals = evals, verbose = 0, #' objective = logregobj, eval_metric = evalerror) #' #' # or as dedicated 'obj' and 'feval' parameters of xgb.train: -#' bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, +#' bst <- xgb.train(param, dtrain, nrounds = 2, evals = evals, #' obj = logregobj, feval = evalerror) #' #' @@ -294,11 +296,11 @@ #' param <- list(max_depth = 2, eta = 1, nthread = nthread, #' objective = "binary:logistic", eval_metric = "auc") #' my_etas <- list(eta = c(0.5, 0.1)) -#' bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0, -#' callbacks = list(cb.reset.parameters(my_etas))) +#' bst <- xgb.train(param, dtrain, nrounds = 2, evals = evals, verbose = 0, +#' callbacks = list(xgb.cb.reset.parameters(my_etas))) #' #' ## Early stopping: -#' bst <- xgb.train(param, dtrain, nrounds = 25, watchlist, +#' bst <- xgb.train(param, dtrain, nrounds = 25, evals = evals, #' early_stopping_rounds = 3) #' #' ## An 'xgboost' interface example: @@ -309,7 +311,7 @@ #' #' @rdname xgb.train #' @export -xgb.train <- function(params = list(), data, nrounds, watchlist = list(), +xgb.train <- function(params = list(), data, nrounds, evals = list(), obj = NULL, feval = NULL, verbose = 1, print_every_n = 1L, early_stopping_rounds = NULL, maximize = NULL, save_period = NULL, save_name = "xgboost.model", @@ -322,68 +324,68 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(), check.custom.obj() check.custom.eval() - # data & watchlist checks + # data & evals checks dtrain <- data if (!inherits(dtrain, "xgb.DMatrix")) stop("second argument dtrain must be xgb.DMatrix") - if (length(watchlist) > 0) { - if (typeof(watchlist) != "list" || - !all(vapply(watchlist, inherits, logical(1), what = 'xgb.DMatrix'))) - stop("watchlist must be a list of xgb.DMatrix elements") - evnames <- names(watchlist) + if (length(evals) > 0) { + if (typeof(evals) != "list" || + !all(vapply(evals, inherits, logical(1), what = 'xgb.DMatrix'))) + stop("'evals' must be a list of xgb.DMatrix elements") + evnames <- names(evals) if (is.null(evnames) || any(evnames == "")) - stop("each element of the watchlist must have a name tag") + stop("each element of 'evals' must have a name tag") } # Handle multiple evaluation metrics given as a list for (m in params$eval_metric) { params <- c(params, list(eval_metric = m)) } - # evaluation printing callback params <- c(params) - print_every_n <- max(as.integer(print_every_n), 1L) - if (!has.callbacks(callbacks, 'cb.print.evaluation') && - verbose) { - callbacks <- add.cb(callbacks, cb.print.evaluation(print_every_n)) + params['validate_parameters'] <- TRUE + if (!("seed" %in% names(params))) { + params[["seed"]] <- sample(.Machine$integer.max, size = 1) } - # evaluation log callback: it is automatically enabled when watchlist is provided - evaluation_log <- list() - if (!has.callbacks(callbacks, 'cb.evaluation.log') && - length(watchlist) > 0) { - callbacks <- add.cb(callbacks, cb.evaluation.log()) + + # callbacks + tmp <- .process.callbacks(callbacks, is_cv = FALSE) + callbacks <- tmp$callbacks + cb_names <- tmp$cb_names + rm(tmp) + + # Early stopping callback (should always come first) + if (!is.null(early_stopping_rounds) && !("early_stop" %in% cb_names)) { + callbacks <- add.callback( + callbacks, + xgb.cb.early.stop( + early_stopping_rounds, + maximize = maximize, + verbose = verbose + ), + as_first_elt = TRUE + ) } - # Model saving callback - if (!is.null(save_period) && - !has.callbacks(callbacks, 'cb.save.model')) { - callbacks <- add.cb(callbacks, cb.save.model(save_period, save_name)) + # evaluation printing callback + print_every_n <- max(as.integer(print_every_n), 1L) + if (verbose && !("print_evaluation" %in% cb_names)) { + callbacks <- add.callback(callbacks, xgb.cb.print.evaluation(print_every_n)) } - # Early stopping callback - stop_condition <- FALSE - if (!is.null(early_stopping_rounds) && - !has.callbacks(callbacks, 'cb.early.stop')) { - callbacks <- add.cb(callbacks, cb.early.stop(early_stopping_rounds, - maximize = maximize, verbose = verbose)) + # evaluation log callback: it is automatically enabled when 'evals' is provided + if (length(evals) && !("evaluation_log" %in% cb_names)) { + callbacks <- add.callback(callbacks, xgb.cb.evaluation.log()) } - - # Sort the callbacks into categories - cb <- categorize.callbacks(callbacks) - params['validate_parameters'] <- TRUE - if (!("seed" %in% names(params))) { - params[["seed"]] <- sample(.Machine$integer.max, size = 1) + # Model saving callback + if (!is.null(save_period) && !("save_model" %in% cb_names)) { + callbacks <- add.callback(callbacks, xgb.cb.save.model(save_period, save_name)) } # The tree updating process would need slightly different handling is_update <- NVL(params[['process_type']], '.') == 'update' - past_evaluation_log <- NULL - if (inherits(xgb_model, "xgb.Booster")) { - past_evaluation_log <- attributes(xgb_model)$evaluation_log - } - # Construct a booster (either a new one or load from xgb_model) bst <- xgb.Booster( params = params, - cachelist = append(watchlist, dtrain), + cachelist = append(evals, dtrain), modelfile = xgb_model ) niter_init <- bst$niter @@ -394,11 +396,6 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(), dtrain ) - # extract parameters that can affect the relationship b/w #trees and #iterations - # Note: it might look like these aren't used, but they need to be defined in this - # environment for the callbacks for work correctly. - num_class <- max(as.numeric(NVL(params[['num_class']], 1)), 1) # nolint - if (is_update && nrounds > niter_init) stop("nrounds cannot be larger than ", niter_init, " (nrounds of xgb_model)") @@ -406,57 +403,83 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(), begin_iteration <- niter_skip + 1 end_iteration <- niter_skip + nrounds + .execute.cb.before.training( + callbacks, + bst, + dtrain, + evals, + begin_iteration, + end_iteration + ) + # the main loop for boosting iterations for (iteration in begin_iteration:end_iteration) { - for (f in cb$pre_iter) f() + .execute.cb.before.iter( + callbacks, + bst, + dtrain, + evals, + iteration + ) xgb.iter.update( - bst = bst, - dtrain = dtrain, - iter = iteration - 1, - obj = obj + bst = bst, + dtrain = dtrain, + iter = iteration - 1, + obj = obj ) - if (length(watchlist) > 0) { - bst_evaluation <- xgb.iter.eval( # nolint: object_usage_linter + bst_evaluation <- NULL + if (length(evals) > 0) { + bst_evaluation <- xgb.iter.eval( bst = bst, - watchlist = watchlist, + evals = evals, iter = iteration - 1, feval = feval ) } - for (f in cb$post_iter) f() + should_stop <- .execute.cb.after.iter( + callbacks, + bst, + dtrain, + evals, + iteration, + bst_evaluation + ) - if (stop_condition) break + if (should_stop) break } - for (f in cb$finalize) f(finalize = TRUE) - # store the evaluation results - keep_evaluation_log <- FALSE - if (length(evaluation_log) > 0 && nrow(evaluation_log) > 0) { - keep_evaluation_log <- TRUE - # include the previous compatible history when available - if (inherits(xgb_model, 'xgb.Booster') && - !is_update && - !is.null(past_evaluation_log) && - isTRUE(all.equal(colnames(evaluation_log), - colnames(past_evaluation_log)))) { - evaluation_log <- rbindlist(list(past_evaluation_log, evaluation_log)) - } - } + cb_outputs <- .execute.cb.after.training( + callbacks, + bst, + dtrain, + evals, + iteration, + bst_evaluation + ) extra_attrs <- list( call = match.call(), - params = params, - callbacks = callbacks + params = params ) - if (keep_evaluation_log) { - extra_attrs$evaluation_log <- evaluation_log - } + curr_attrs <- attributes(bst) - attributes(bst) <- c(curr_attrs, extra_attrs) + if (NROW(curr_attrs)) { + curr_attrs <- curr_attrs[ + setdiff( + names(curr_attrs), + c(names(extra_attrs), names(cb_outputs)) + ) + ] + } + curr_attrs <- c(extra_attrs, curr_attrs) + if (NROW(cb_outputs)) { + curr_attrs <- c(curr_attrs, cb_outputs) + } + attributes(bst) <- curr_attrs return(bst) } diff --git a/R-package/R/xgboost.R b/R-package/R/xgboost.R index 170aa5ffd5be..a1d37358162c 100644 --- a/R-package/R/xgboost.R +++ b/R-package/R/xgboost.R @@ -18,9 +18,9 @@ xgboost <- function(data = NULL, label = NULL, missing = NA, weight = NULL, nthread = merged$nthread ) - watchlist <- list(train = dtrain) + evals <- list(train = dtrain) - bst <- xgb.train(params, dtrain, nrounds, watchlist, verbose = verbose, print_every_n = print_every_n, + bst <- xgb.train(params, dtrain, nrounds, evals, verbose = verbose, print_every_n = print_every_n, early_stopping_rounds = early_stopping_rounds, maximize = maximize, save_period = save_period, save_name = save_name, xgb_model = xgb_model, callbacks = callbacks, ...) @@ -82,12 +82,8 @@ NULL NULL # Various imports -#' @importClassesFrom Matrix dgCMatrix dgeMatrix dgRMatrix -#' @importFrom Matrix colSums +#' @importClassesFrom Matrix dgCMatrix dgRMatrix CsparseMatrix #' @importFrom Matrix sparse.model.matrix -#' @importFrom Matrix sparseVector -#' @importFrom Matrix sparseMatrix -#' @importFrom Matrix t #' @importFrom data.table data.table #' @importFrom data.table is.data.table #' @importFrom data.table as.data.table @@ -103,6 +99,7 @@ NULL #' @importFrom stats coef #' @importFrom stats predict #' @importFrom stats median +#' @importFrom stats sd #' @importFrom stats variable.names #' @importFrom utils head #' @importFrom graphics barplot diff --git a/R-package/demo/basic_walkthrough.R b/R-package/demo/basic_walkthrough.R index 3dbbe0586f44..9403bac2064c 100644 --- a/R-package/demo/basic_walkthrough.R +++ b/R-package/demo/basic_walkthrough.R @@ -74,17 +74,17 @@ print(paste("sum(abs(pred3-pred))=", sum(abs(pred3 - pred)))) # to use advanced features, we need to put data in xgb.DMatrix dtrain <- xgb.DMatrix(data = train$data, label = train$label) dtest <- xgb.DMatrix(data = test$data, label = test$label) -#---------------Using watchlist---------------- -# watchlist is a list of xgb.DMatrix, each of them is tagged with name -watchlist <- list(train = dtrain, test = dtest) -# to train with watchlist, use xgb.train, which contains more advanced features -# watchlist allows us to monitor the evaluation result on all data in the list -print("Train xgboost using xgb.train with watchlist") -bst <- xgb.train(data = dtrain, max_depth = 2, eta = 1, nrounds = 2, watchlist = watchlist, +#---------------Using an evaluation set---------------- +# 'evals' is a list of xgb.DMatrix, each of them is tagged with name +evals <- list(train = dtrain, test = dtest) +# to train with an evaluation set, use xgb.train, which contains more advanced features +# 'evals' argument allows us to monitor the evaluation result on all data in the list +print("Train xgboost using xgb.train with evaluation data") +bst <- xgb.train(data = dtrain, max_depth = 2, eta = 1, nrounds = 2, evals = evals, nthread = 2, objective = "binary:logistic") # we can change evaluation metrics, or use multiple evaluation metrics -print("train xgboost using xgb.train with watchlist, watch logloss and error") -bst <- xgb.train(data = dtrain, max_depth = 2, eta = 1, nrounds = 2, watchlist = watchlist, +print("train xgboost using xgb.train with evaluation data, watch logloss and error") +bst <- xgb.train(data = dtrain, max_depth = 2, eta = 1, nrounds = 2, evals = evals, eval_metric = "error", eval_metric = "logloss", nthread = 2, objective = "binary:logistic") @@ -92,7 +92,7 @@ bst <- xgb.train(data = dtrain, max_depth = 2, eta = 1, nrounds = 2, watchlist = xgb.DMatrix.save(dtrain, "dtrain.buffer") # to load it in, simply call xgb.DMatrix dtrain2 <- xgb.DMatrix("dtrain.buffer") -bst <- xgb.train(data = dtrain2, max_depth = 2, eta = 1, nrounds = 2, watchlist = watchlist, +bst <- xgb.train(data = dtrain2, max_depth = 2, eta = 1, nrounds = 2, evals = evals, nthread = 2, objective = "binary:logistic") # information can be extracted from xgb.DMatrix using getinfo label <- getinfo(dtest, "label") diff --git a/R-package/demo/boost_from_prediction.R b/R-package/demo/boost_from_prediction.R index 1a3d55369d2f..75af70dba0d7 100644 --- a/R-package/demo/boost_from_prediction.R +++ b/R-package/demo/boost_from_prediction.R @@ -5,14 +5,14 @@ data(agaricus.test, package = 'xgboost') dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label) dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label) -watchlist <- list(eval = dtest, train = dtrain) +evals <- list(eval = dtest, train = dtrain) ### # advanced: start from a initial base prediction # print('start running example to start from a initial prediction') # train xgboost for 1 round param <- list(max_depth = 2, eta = 1, nthread = 2, objective = 'binary:logistic') -bst <- xgb.train(param, dtrain, 1, watchlist) +bst <- xgb.train(param, dtrain, 1, evals) # Note: we need the margin value instead of transformed prediction in set_base_margin # do predict with output_margin=TRUE, will always give you margin values before logistic transformation ptrain <- predict(bst, dtrain, outputmargin = TRUE) @@ -23,4 +23,4 @@ setinfo(dtrain, "base_margin", ptrain) setinfo(dtest, "base_margin", ptest) print('this is result of boost from initial prediction') -bst <- xgb.train(params = param, data = dtrain, nrounds = 1, watchlist = watchlist) +bst <- xgb.train(params = param, data = dtrain, nrounds = 1, evals = evals) diff --git a/R-package/demo/custom_objective.R b/R-package/demo/custom_objective.R index 35201332c5f6..03d7b346471b 100644 --- a/R-package/demo/custom_objective.R +++ b/R-package/demo/custom_objective.R @@ -8,7 +8,7 @@ dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label) # note: for customized objective function, we leave objective as default # note: what we are getting is margin value in prediction # you must know what you are doing -watchlist <- list(eval = dtest, train = dtrain) +evals <- list(eval = dtest, train = dtrain) num_round <- 2 # user define objective function, given prediction, return gradient and second order gradient @@ -38,7 +38,7 @@ param <- list(max_depth = 2, eta = 1, nthread = 2, verbosity = 0, print('start training with user customized objective') # training with customized objective, we can also do step by step training # simply look at xgboost.py's implementation of train -bst <- xgb.train(param, dtrain, num_round, watchlist) +bst <- xgb.train(param, dtrain, num_round, evals) # # there can be cases where you want additional information @@ -62,4 +62,4 @@ param <- list(max_depth = 2, eta = 1, nthread = 2, verbosity = 0, print('start training with user customized objective, with additional attributes in DMatrix') # training with customized objective, we can also do step by step training # simply look at xgboost.py's implementation of train -bst <- xgb.train(param, dtrain, num_round, watchlist) +bst <- xgb.train(param, dtrain, num_round, evals) diff --git a/R-package/demo/early_stopping.R b/R-package/demo/early_stopping.R index 04da1382f031..057440882567 100644 --- a/R-package/demo/early_stopping.R +++ b/R-package/demo/early_stopping.R @@ -8,7 +8,7 @@ dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label) # note: what we are getting is margin value in prediction # you must know what you are doing param <- list(max_depth = 2, eta = 1, nthread = 2, verbosity = 0) -watchlist <- list(eval = dtest) +evals <- list(eval = dtest) num_round <- 20 # user define objective function, given prediction, return gradient and second order gradient # this is log likelihood loss @@ -32,7 +32,7 @@ evalerror <- function(preds, dtrain) { } print('start training with early Stopping setting') -bst <- xgb.train(param, dtrain, num_round, watchlist, +bst <- xgb.train(param, dtrain, num_round, evals, objective = logregobj, eval_metric = evalerror, maximize = FALSE, early_stopping_round = 3) bst <- xgb.cv(param, dtrain, num_round, nfold = 5, diff --git a/R-package/demo/generalized_linear_model.R b/R-package/demo/generalized_linear_model.R index c24fe72cbcad..d29a6dc5be58 100644 --- a/R-package/demo/generalized_linear_model.R +++ b/R-package/demo/generalized_linear_model.R @@ -25,9 +25,9 @@ param <- list(objective = "binary:logistic", booster = "gblinear", ## # the rest of settings are the same ## -watchlist <- list(eval = dtest, train = dtrain) +evals <- list(eval = dtest, train = dtrain) num_round <- 2 -bst <- xgb.train(param, dtrain, num_round, watchlist) +bst <- xgb.train(param, dtrain, num_round, evals) ypred <- predict(bst, dtest) labels <- getinfo(dtest, 'label') cat('error of preds=', mean(as.numeric(ypred > 0.5) != labels), '\n') diff --git a/R-package/demo/gpu_accelerated.R b/R-package/demo/gpu_accelerated.R index 14ed9392b7d1..617a63e74542 100644 --- a/R-package/demo/gpu_accelerated.R +++ b/R-package/demo/gpu_accelerated.R @@ -23,7 +23,7 @@ y <- rbinom(N, 1, plogis(m)) tr <- sample.int(N, N * 0.75) dtrain <- xgb.DMatrix(X[tr, ], label = y[tr]) dtest <- xgb.DMatrix(X[-tr, ], label = y[-tr]) -wl <- list(train = dtrain, test = dtest) +evals <- list(train = dtrain, test = dtest) # An example of running 'gpu_hist' algorithm # which is @@ -35,11 +35,11 @@ wl <- list(train = dtrain, test = dtest) param <- list(objective = 'reg:logistic', eval_metric = 'auc', subsample = 0.5, nthread = 4, max_bin = 64, tree_method = 'gpu_hist') pt <- proc.time() -bst_gpu <- xgb.train(param, dtrain, watchlist = wl, nrounds = 50) +bst_gpu <- xgb.train(param, dtrain, evals = evals, nrounds = 50) proc.time() - pt # Compare to the 'hist' algorithm: param$tree_method <- 'hist' pt <- proc.time() -bst_hist <- xgb.train(param, dtrain, watchlist = wl, nrounds = 50) +bst_hist <- xgb.train(param, dtrain, evals = evals, nrounds = 50) proc.time() - pt diff --git a/R-package/demo/predict_first_ntree.R b/R-package/demo/predict_first_ntree.R index 179c18c707f4..ba15ab39a74f 100644 --- a/R-package/demo/predict_first_ntree.R +++ b/R-package/demo/predict_first_ntree.R @@ -6,11 +6,11 @@ dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label) dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label) param <- list(max_depth = 2, eta = 1, objective = 'binary:logistic') -watchlist <- list(eval = dtest, train = dtrain) +evals <- list(eval = dtest, train = dtrain) nrounds <- 2 # training the model for two rounds -bst <- xgb.train(param, dtrain, nrounds, nthread = 2, watchlist) +bst <- xgb.train(param, dtrain, nrounds, nthread = 2, evals = evals) cat('start testing prediction from first n trees\n') labels <- getinfo(dtest, 'label') diff --git a/R-package/demo/predict_leaf_indices.R b/R-package/demo/predict_leaf_indices.R index 21b6fa71d0b7..a57baf668896 100644 --- a/R-package/demo/predict_leaf_indices.R +++ b/R-package/demo/predict_leaf_indices.R @@ -43,7 +43,6 @@ colnames(new.features.test) <- colnames(new.features.train) # learning with new features new.dtrain <- xgb.DMatrix(data = new.features.train, label = agaricus.train$label) new.dtest <- xgb.DMatrix(data = new.features.test, label = agaricus.test$label) -watchlist <- list(train = new.dtrain) bst <- xgb.train(params = param, data = new.dtrain, nrounds = nrounds, nthread = 2) # Model accuracy with new features diff --git a/R-package/demo/tweedie_regression.R b/R-package/demo/tweedie_regression.R index dfaf6a2ae2ce..b07858e761fa 100644 --- a/R-package/demo/tweedie_regression.R +++ b/R-package/demo/tweedie_regression.R @@ -39,7 +39,7 @@ bst <- xgb.train( data = d_train, params = params, maximize = FALSE, - watchlist = list(train = d_train), + evals = list(train = d_train), nrounds = 20) var_imp <- xgb.importance(attr(x, 'Dimnames')[[2]], model = bst) diff --git a/R-package/man/callbacks.Rd b/R-package/man/callbacks.Rd deleted file mode 100644 index 9f6f69015dcb..000000000000 --- a/R-package/man/callbacks.Rd +++ /dev/null @@ -1,37 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/callbacks.R -\name{callbacks} -\alias{callbacks} -\title{Callback closures for booster training.} -\description{ -These are used to perform various service tasks either during boosting iterations or at the end. -This approach helps to modularize many of such tasks without bloating the main training methods, -and it offers . -} -\details{ -By default, a callback function is run after each boosting iteration. -An R-attribute \code{is_pre_iteration} could be set for a callback to define a pre-iteration function. - -When a callback function has \code{finalize} parameter, its finalizer part will also be run after -the boosting is completed. - -WARNING: side-effects!!! Be aware that these callback functions access and modify things in -the environment from which they are called from, which is a fairly uncommon thing to do in R. - -To write a custom callback closure, make sure you first understand the main concepts about R environments. -Check either R documentation on \code{\link[base]{environment}} or the -\href{http://adv-r.had.co.nz/Environments.html}{Environments chapter} from the "Advanced R" -book by Hadley Wickham. Further, the best option is to read the code of some of the existing callbacks - -choose ones that do something similar to what you want to achieve. Also, you would need to get familiar -with the objects available inside of the \code{xgb.train} and \code{xgb.cv} internal environments. -} -\seealso{ -\code{\link{cb.print.evaluation}}, -\code{\link{cb.evaluation.log}}, -\code{\link{cb.reset.parameters}}, -\code{\link{cb.early.stop}}, -\code{\link{cb.save.model}}, -\code{\link{cb.cv.predict}}, -\code{\link{xgb.train}}, -\code{\link{xgb.cv}} -} diff --git a/R-package/man/cb.early.stop.Rd b/R-package/man/cb.early.stop.Rd deleted file mode 100644 index 7cd51a3ce563..000000000000 --- a/R-package/man/cb.early.stop.Rd +++ /dev/null @@ -1,62 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/callbacks.R -\name{cb.early.stop} -\alias{cb.early.stop} -\title{Callback closure to activate the early stopping.} -\usage{ -cb.early.stop( - stopping_rounds, - maximize = FALSE, - metric_name = NULL, - verbose = TRUE -) -} -\arguments{ -\item{stopping_rounds}{The number of rounds with no improvement in -the evaluation metric in order to stop the training.} - -\item{maximize}{whether to maximize the evaluation metric} - -\item{metric_name}{the name of an evaluation column to use as a criteria for early -stopping. If not set, the last column would be used. -Let's say the test data in \code{watchlist} was labelled as \code{dtest}, -and one wants to use the AUC in test data for early stopping regardless of where -it is in the \code{watchlist}, then one of the following would need to be set: -\code{metric_name='dtest-auc'} or \code{metric_name='dtest_auc'}. -All dash '-' characters in metric names are considered equivalent to '_'.} - -\item{verbose}{whether to print the early stopping information.} -} -\description{ -Callback closure to activate the early stopping. -} -\details{ -This callback function determines the condition for early stopping -by setting the \code{stop_condition = TRUE} flag in its calling frame. - -The following additional fields are assigned to the model's R object: -\itemize{ -\item \code{best_score} the evaluation score at the best iteration -\item \code{best_iteration} at which boosting iteration the best score has occurred (1-based index) -} -The Same values are also stored as xgb-attributes: -\itemize{ -\item \code{best_iteration} is stored as a 0-based iteration index (for interoperability of binary models) -\item \code{best_msg} message string is also stored. -} - -At least one data element is required in the evaluation watchlist for early stopping to work. - -Callback function expects the following values to be set in its calling frame: -\code{stop_condition}, -\code{bst_evaluation}, -\code{rank}, -\code{bst} (or \code{bst_folds} and \code{basket}), -\code{iteration}, -\code{begin_iteration}, -\code{end_iteration}, -} -\seealso{ -\code{\link{callbacks}}, -\code{\link{xgb.attr}} -} diff --git a/R-package/man/cb.evaluation.log.Rd b/R-package/man/cb.evaluation.log.Rd deleted file mode 100644 index 94f8a02e6227..000000000000 --- a/R-package/man/cb.evaluation.log.Rd +++ /dev/null @@ -1,31 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/callbacks.R -\name{cb.evaluation.log} -\alias{cb.evaluation.log} -\title{Callback closure for logging the evaluation history} -\usage{ -cb.evaluation.log() -} -\description{ -Callback closure for logging the evaluation history -} -\details{ -This callback function appends the current iteration evaluation results \code{bst_evaluation} -available in the calling parent frame to the \code{evaluation_log} list in a calling frame. - -The finalizer callback (called with \code{finalize = TURE} in the end) converts -the \code{evaluation_log} list into a final data.table. - -The iteration evaluation result \code{bst_evaluation} must be a named numeric vector. - -Note: in the column names of the final data.table, the dash '-' character is replaced with -the underscore '_' in order to make the column names more like regular R identifiers. - -Callback function expects the following values to be set in its calling frame: -\code{evaluation_log}, -\code{bst_evaluation}, -\code{iteration}. -} -\seealso{ -\code{\link{callbacks}} -} diff --git a/R-package/man/cb.print.evaluation.Rd b/R-package/man/cb.print.evaluation.Rd deleted file mode 100644 index 59b9ba65ea30..000000000000 --- a/R-package/man/cb.print.evaluation.Rd +++ /dev/null @@ -1,29 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/callbacks.R -\name{cb.print.evaluation} -\alias{cb.print.evaluation} -\title{Callback closure for printing the result of evaluation} -\usage{ -cb.print.evaluation(period = 1, showsd = TRUE) -} -\arguments{ -\item{period}{results would be printed every number of periods} - -\item{showsd}{whether standard deviations should be printed (when available)} -} -\description{ -Callback closure for printing the result of evaluation -} -\details{ -The callback function prints the result of evaluation at every \code{period} iterations. -The initial and the last iteration's evaluations are always printed. - -Callback function expects the following values to be set in its calling frame: -\code{bst_evaluation} (also \code{bst_evaluation_err} when available), -\code{iteration}, -\code{begin_iteration}, -\code{end_iteration}. -} -\seealso{ -\code{\link{callbacks}} -} diff --git a/R-package/man/cb.save.model.Rd b/R-package/man/cb.save.model.Rd deleted file mode 100644 index 7701ad9900e5..000000000000 --- a/R-package/man/cb.save.model.Rd +++ /dev/null @@ -1,40 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/callbacks.R -\name{cb.save.model} -\alias{cb.save.model} -\title{Callback closure for saving a model file.} -\usage{ -cb.save.model(save_period = 0, save_name = "xgboost.ubj") -} -\arguments{ -\item{save_period}{save the model to disk after every -\code{save_period} iterations; 0 means save the model at the end.} - -\item{save_name}{the name or path for the saved model file. - -\if{html}{\out{
}}\preformatted{ Note that the format of the model being saved is determined by the file - extension specified here (see \link{xgb.save} for details about how it works). - - It can contain a \code{\link[base]{sprintf}} formatting specifier - to include the integer iteration number in the file name. - E.g., with \code{save_name} = 'xgboost_\%04d.ubj', - the file saved at iteration 50 would be named "xgboost_0050.ubj". -}\if{html}{\out{
}}} -} -\description{ -Callback closure for saving a model file. -} -\details{ -This callback function allows to save an xgb-model file, either periodically after each \code{save_period}'s or at the end. - -Callback function expects the following values to be set in its calling frame: -\code{bst}, -\code{iteration}, -\code{begin_iteration}, -\code{end_iteration}. -} -\seealso{ -\link{xgb.save} - -\code{\link{callbacks}} -} diff --git a/R-package/man/print.xgb.cv.Rd b/R-package/man/print.xgb.cv.Rd index 05ad61eed8ac..74fc15d01fb9 100644 --- a/R-package/man/print.xgb.cv.Rd +++ b/R-package/man/print.xgb.cv.Rd @@ -23,8 +23,8 @@ including the best iteration (when available). \examples{ data(agaricus.train, package='xgboost') train <- agaricus.train -cv <- xgb.cv(data = train$data, label = train$label, nfold = 5, max_depth = 2, - eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic") +cv <- xgb.cv(data = xgb.DMatrix(train$data, label = train$label), nfold = 5, max_depth = 2, + eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic") print(cv) print(cv, verbose=TRUE) diff --git a/R-package/man/xgb.Callback.Rd b/R-package/man/xgb.Callback.Rd new file mode 100644 index 000000000000..b4edcd97842e --- /dev/null +++ b/R-package/man/xgb.Callback.Rd @@ -0,0 +1,248 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/callbacks.R +\name{xgb.Callback} +\alias{xgb.Callback} +\title{XGBoost Callback Constructor} +\usage{ +xgb.Callback( + cb_name = "custom_callback", + env = new.env(), + f_before_training = function(env, model, data, evals, begin_iteration, end_iteration) + NULL, + f_before_iter = function(env, model, data, evals, iteration) NULL, + f_after_iter = function(env, model, data, evals, iteration, iter_feval) NULL, + f_after_training = function(env, model, data, evals, iteration, final_feval, + prev_cb_res) NULL +) +} +\arguments{ +\item{cb_name}{Name for the callback. + +If the callback produces some non-NULL result (from executing the function passed under +\code{f_after_training}), that result will be added as an R attribute to the resulting booster +(or as a named element in the result of CV), with the attribute name specified here. + +Names of callbacks must be unique - i.e. there cannot be two callbacks with the same name.} + +\item{env}{An environment object that will be passed to the different functions in the callback. +Note that this environment will not be shared with other callbacks.} + +\item{f_before_training}{A function that will be executed before the training has started. + +If passing \code{NULL} for this or for the other function inputs, then no function will be executed. + +If passing a function, it will be called with parameters supplied as non-named arguments +matching the function signatures that are shown in the default value for each function argument.} + +\item{f_before_iter}{A function that will be executed before each boosting round. + +This function can signal whether the training should be finalized or not, by outputting +a value that evaluates to \code{TRUE} - i.e. if the output from the function provided here at +a given round is \code{TRUE}, then training will be stopped before the current iteration happens. + +Return values of \code{NULL} will be interpreted as \code{FALSE}.} + +\item{f_after_iter}{A function that will be executed after each boosting round. + +This function can signal whether the training should be finalized or not, by outputting +a value that evaluates to \code{TRUE} - i.e. if the output from the function provided here at +a given round is \code{TRUE}, then training will be stopped at that round. + +Return values of \code{NULL} will be interpreted as \code{FALSE}.} + +\item{f_after_training}{A function that will be executed after training is finished. + +This function can optionally output something non-NULL, which will become part of the R +attributes of the booster (assuming one passes \code{keep_extra_attributes=TRUE} to \link{xgb.train}) +under the name supplied for parameter \code{cb_name} imn the case of \link{xgb.train}; or a part +of the named elements in the result of \link{xgb.cv}.} +} +\value{ +An \code{xgb.Callback} object, which can be passed to \link{xgb.train} or \link{xgb.cv}. +} +\description{ +Constructor for defining the structure of callback functions that can be executed +at different stages of model training (before / after training, before / after each boosting +iteration). +} +\details{ +Arguments that will be passed to the supplied functions are as follows:\itemize{ + +\item env The same environment that is passed under argument \code{env}. + +It may be modified by the functions in order to e.g. keep tracking of what happens +across iterations or similar. + +This environment is only used by the functions supplied to the callback, and will +not be kept after the model fitting function terminates (see parameter \code{f_after_training}). + +\item model The booster object when using \link{xgb.train}, or the folds when using +\link{xgb.cv}. + +For \link{xgb.cv}, folds are a list with a structure as follows:\itemize{ +\item \code{dtrain}: The training data for the fold (as an \code{xgb.DMatrix} object). +\item \code{bst}: Rhe \code{xgb.Booster} object for the fold. +\item \code{evals}: A list containing two DMatrices, with names \code{train} and \code{test} +(\code{test} is the held-out data for the fold). +\item \code{index}: The indices of the hold-out data for that fold (base-1 indexing), +from which the \code{test} entry in \code{evals} was obtained. +} + +This object should \bold{not} be in-place modified in ways that conflict with the +training (e.g. resetting the parameters for a training update in a way that resets +the number of rounds to zero in order to overwrite rounds). + +Note that any R attributes that are assigned to the booster during the callback functions, +will not be kept thereafter as the booster object variable is not re-assigned during +training. It is however possible to set C-level attributes of the booster through +\link{xgb.attr} or \link{xgb.attributes}, which should remain available for the rest +of the iterations and after the training is done. + +For keeping variables across iterations, it's recommended to use \code{env} instead. +\item data The data to which the model is being fit, as an \code{xgb.DMatrix} object. + +Note that, for \link{xgb.cv}, this will be the full data, while data for the specific +folds can be found in the \code{model} object. + +\item evals The evaluation data, as passed under argument \code{evals} to +\link{xgb.train}. + +For \link{xgb.cv}, this will always be \code{NULL}. + +\item begin_iteration Index of the first boosting iteration that will be executed +(base-1 indexing). + +This will typically be '1', but when using training continuation, depending on the +parameters for updates, boosting rounds will be continued from where the previous +model ended, in which case this will be larger than 1. + +\item end_iteration Index of the last boostign iteration that will be executed +(base-1 indexing, inclusive of this end). + +It should match with argument \code{nrounds} passed to \link{xgb.train} or \link{xgb.cv}. + +Note that boosting might be interrupted before reaching this last iteration, for +example by using the early stopping callback \link{xgb.cb.early.stop}. + +\item iteration Index of the iteration number that is being executed (first iteration +will be the same as parameter \code{begin_iteration}, then next one will add +1, and so on). + +\item iter_feval Evaluation metrics for \code{evals} that were supplied, either +determined by the objective, or by parameter \code{feval}. + +For \link{xgb.train}, this will be a named vector with one entry per element in +\code{evals}, where the names are determined as 'evals name' + '-' + 'metric name' - for +example, if \code{evals} contains an entry named "tr" and the metric is "rmse", +this will be a one-element vector with name "tr-rmse". + +For \link{xgb.cv}, this will be a 2d matrix with dimensions \verb{[length(evals), nfolds]}, +where the row names will follow the same naming logic as the one-dimensional vector +that is passed in \link{xgb.train}. + +Note that, internally, the built-in callbacks such as \link{xgb.cb.print.evaluation} summarize +this table by calculating the row-wise means and standard deviations. + +\item final_feval The evaluation results after the last boosting round is executed +(same format as \code{iter_feval}, and will be the exact same input as passed under +\code{iter_feval} to the last round that is executed during model fitting). + +\item prev_cb_res Result from a previous run of a callback sharing the same name +(as given by parameter \code{cb_name}) when conducting training continuation, if there +was any in the booster R attributes. + +Some times, one might want to append the new results to the previous one, and this will +be done automatically by the built-in callbacks such as \link{xgb.cb.evaluation.log}, +which will append the new rows to the previous table. + +If no such previous callback result is available (which it never will when fitting +a model from start instead of updating an existing model), this will be \code{NULL}. + +For \link{xgb.cv}, which doesn't support training continuation, this will always be \code{NULL}. +} + +The following names (\code{cb_name} values) are reserved for internal callbacks:\itemize{ +\item print_evaluation +\item evaluation_log +\item reset_parameters +\item early_stop +\item save_model +\item cv_predict +\item gblinear_history +} + +The following names are reserved for other non-callback attributes:\itemize{ +\item names +\item class +\item call +\item params +\item niter +\item nfeatures +\item folds +} + +When using the built-in early stopping callback (\link{xgb.cb.early.stop}), said callback +will always be executed before the others, as it sets some booster C-level attributes +that other callbacks might also use. Otherwise, the order of execution will match with +the order in which the callbacks are passed to the model fitting function. +} +\examples{ +# Example constructing a custom callback that calculates +# squared error on the training data (no separate test set), +# and outputs the per-iteration results. +ssq_callback <- xgb.Callback( + cb_name = "ssq", + f_before_training = function(env, model, data, evals, + begin_iteration, end_iteration) { + # A vector to keep track of a number at each iteration + env$logs <- rep(NA_real_, end_iteration - begin_iteration + 1) + }, + f_after_iter = function(env, model, data, evals, iteration, iter_feval) { + # This calculates the sum of squared errors on the training data. + # Note that this can be better done by passing an 'evals' entry, + # but this demonstrates a way in which callbacks can be structured. + pred <- predict(model, data) + err <- pred - getinfo(data, "label") + sq_err <- sum(err^2) + env$logs[iteration] <- sq_err + cat( + sprintf( + "Squared error at iteration \%d: \%.2f\n", + iteration, sq_err + ) + ) + + # A return value of 'TRUE' here would signal to finalize the training + return(FALSE) + }, + f_after_training = function(env, model, data, evals, iteration, + final_feval, prev_cb_res) { + return(env$logs) + } +) + +data(mtcars) +y <- mtcars$mpg +x <- as.matrix(mtcars[, -1]) +dm <- xgb.DMatrix(x, label = y, nthread = 1) +model <- xgb.train( + data = dm, + params = list(objective = "reg:squarederror", nthread = 1), + nrounds = 5, + callbacks = list(ssq_callback), + keep_extra_attributes = TRUE +) + +# Result from 'f_after_iter' will be available as an attribute +attributes(model)$ssq +} +\seealso{ +Built-in callbacks:\itemize{ +\item \link{xgb.cb.print.evaluation} +\item \link{xgb.cb.evaluation.log} +\item \link{xgb.cb.reset.parameters} +\item \link{xgb.cb.early.stop} +\item \link{xgb.cb.save.model} +\item \link{xgb.cb.cv.predict} +\item \link{xgb.cb.gblinear.history} +} +} diff --git a/R-package/man/cb.cv.predict.Rd b/R-package/man/xgb.cb.cv.predict.Rd similarity index 53% rename from R-package/man/cb.cv.predict.Rd rename to R-package/man/xgb.cb.cv.predict.Rd index 4cabac1c9569..d2d9a084be13 100644 --- a/R-package/man/cb.cv.predict.Rd +++ b/R-package/man/xgb.cb.cv.predict.Rd @@ -1,16 +1,27 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/callbacks.R -\name{cb.cv.predict} -\alias{cb.cv.predict} -\title{Callback closure for returning cross-validation based predictions.} +\name{xgb.cb.cv.predict} +\alias{xgb.cb.cv.predict} +\title{Callback for returning cross-validation based predictions.} \usage{ -cb.cv.predict(save_models = FALSE) +xgb.cb.cv.predict(save_models = FALSE, outputmargin = FALSE) } \arguments{ -\item{save_models}{a flag for whether to save the folds' models.} +\item{save_models}{A flag for whether to save the folds' models.} + +\item{outputmargin}{Whether to save margin predictions (same effect as passing this +parameter to \link{predict.xgb.Booster}).} } \value{ -Predictions are returned inside of the \code{pred} element, which is either a vector or a matrix, +An \code{xgb.Callback} object, which can be passed to \link{xgb.cv}, +but \bold{not} to \link{xgb.train}. +} +\description{ +This callback function saves predictions for all of the test folds, +and also allows to save the folds' models. +} +\details{ +Predictions are saved inside of the \code{pred} element, which is either a vector or a matrix, depending on the number of prediction outputs per data row. The order of predictions corresponds to the order of rows in the original dataset. Note that when a custom \code{folds} list is provided in \code{xgb.cv}, the predictions would only be returned properly when this list is a @@ -19,23 +30,3 @@ meaningful when user-provided folds have overlapping indices as in, e.g., random When some of the indices in the training dataset are not included into user-provided \code{folds}, their prediction value would be \code{NA}. } -\description{ -Callback closure for returning cross-validation based predictions. -} -\details{ -This callback function saves predictions for all of the test folds, -and also allows to save the folds' models. - -It is a "finalizer" callback and it uses early stopping information whenever it is available, -thus it must be run after the early stopping callback if the early stopping is used. - -Callback function expects the following values to be set in its calling frame: -\code{bst_folds}, -\code{basket}, -\code{data}, -\code{end_iteration}, -\code{params}, -} -\seealso{ -\code{\link{callbacks}} -} diff --git a/R-package/man/xgb.cb.early.stop.Rd b/R-package/man/xgb.cb.early.stop.Rd new file mode 100644 index 000000000000..2a70f4943d92 --- /dev/null +++ b/R-package/man/xgb.cb.early.stop.Rd @@ -0,0 +1,55 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/callbacks.R +\name{xgb.cb.early.stop} +\alias{xgb.cb.early.stop} +\title{Callback to activate early stopping} +\usage{ +xgb.cb.early.stop( + stopping_rounds, + maximize = FALSE, + metric_name = NULL, + verbose = TRUE, + keep_all_iter = TRUE +) +} +\arguments{ +\item{stopping_rounds}{The number of rounds with no improvement in +the evaluation metric in order to stop the training.} + +\item{maximize}{Whether to maximize the evaluation metric.} + +\item{metric_name}{The name of an evaluation column to use as a criteria for early +stopping. If not set, the last column would be used. +Let's say the test data in \code{evals} was labelled as \code{dtest}, +and one wants to use the AUC in test data for early stopping regardless of where +it is in the \code{evals}, then one of the following would need to be set: +\code{metric_name='dtest-auc'} or \code{metric_name='dtest_auc'}. +All dash '-' characters in metric names are considered equivalent to '_'.} + +\item{verbose}{Whether to print the early stopping information.} + +\item{keep_all_iter}{Whether to keep all of the boosting rounds that were produced +in the resulting object. If passing \code{FALSE}, will only keep the boosting rounds +up to the detected best iteration, discarding the ones that come after.} +} +\value{ +An \code{xgb.Callback} object, which can be passed to \link{xgb.train} or \link{xgb.cv}. +} +\description{ +This callback function determines the condition for early stopping. + +The following attributes are assigned to the booster's object: +\itemize{ +\item \code{best_score} the evaluation score at the best iteration +\item \code{best_iteration} at which boosting iteration the best score has occurred +(0-based index for interoperability of binary models) +} + +The same values are also stored as R attributes as a result of the callback, plus an additional +attribute \code{stopped_by_max_rounds} which indicates whether an early stopping by the \code{stopping_rounds} +condition occurred. Note that the \code{best_iteration} that is stored under R attributes will follow +base-1 indexing, so it will be larger by '1' than the C-level 'best_iteration' that is accessed +through \link{xgb.attr} or \link{xgb.attributes}. + +At least one dataset is required in \code{evals} for early stopping to work. +} diff --git a/R-package/man/xgb.cb.evaluation.log.Rd b/R-package/man/xgb.cb.evaluation.log.Rd new file mode 100644 index 000000000000..4cc6ef636c66 --- /dev/null +++ b/R-package/man/xgb.cb.evaluation.log.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/callbacks.R +\name{xgb.cb.evaluation.log} +\alias{xgb.cb.evaluation.log} +\title{Callback for logging the evaluation history} +\usage{ +xgb.cb.evaluation.log() +} +\value{ +An \code{xgb.Callback} object, which can be passed to \link{xgb.train} or \link{xgb.cv}. +} +\description{ +Callback for logging the evaluation history +} +\details{ +This callback creates a table with per-iteration evaluation metrics (see parameters +\code{evals} and \code{feval} in \link{xgb.train}). + +Note: in the column names of the final data.table, the dash '-' character is replaced with +the underscore '_' in order to make the column names more like regular R identifiers. +} +\seealso{ +\link{xgb.cb.print.evaluation} +} diff --git a/R-package/man/cb.gblinear.history.Rd b/R-package/man/xgb.cb.gblinear.history.Rd similarity index 63% rename from R-package/man/cb.gblinear.history.Rd rename to R-package/man/xgb.cb.gblinear.history.Rd index 2a03c14db2f6..0ebaa4685030 100644 --- a/R-package/man/cb.gblinear.history.Rd +++ b/R-package/man/xgb.cb.gblinear.history.Rd @@ -1,37 +1,48 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/callbacks.R -\name{cb.gblinear.history} -\alias{cb.gblinear.history} -\title{Callback closure for collecting the model coefficients history of a gblinear booster -during its training.} +\name{xgb.cb.gblinear.history} +\alias{xgb.cb.gblinear.history} +\title{Callback for collecting coefficients history of a gblinear booster} \usage{ -cb.gblinear.history(sparse = FALSE) +xgb.cb.gblinear.history(sparse = FALSE) } \arguments{ -\item{sparse}{when set to FALSE/TRUE, a dense/sparse matrix is used to store the result. +\item{sparse}{when set to \code{FALSE}/\code{TRUE}, a dense/sparse matrix is used to store the result. Sparse format is useful when one expects only a subset of coefficients to be non-zero, when using the "thrifty" feature selector with fairly small number of top features selected per iteration.} } \value{ -Results are stored in the \code{coefs} element of the closure. -The \code{\link{xgb.gblinear.history}} convenience function provides an easy -way to access it. -With \code{xgb.train}, it is either a dense of a sparse matrix. -While with \code{xgb.cv}, it is a list (an element per each fold) of such -matrices. +An \code{xgb.Callback} object, which can be passed to \link{xgb.train} or \link{xgb.cv}. } \description{ -Callback closure for collecting the model coefficients history of a gblinear booster -during its training. +Callback for collecting coefficients history of a gblinear booster } \details{ To keep things fast and simple, gblinear booster does not internally store the history of linear model coefficients at each boosting iteration. This callback provides a workaround for storing the coefficients' path, by extracting them after each training iteration. -Callback function expects the following values to be set in its calling frame: -\code{bst} (or \code{bst_folds}). +This callback will construct a matrix where rows are boosting iterations and columns are +feature coefficients (same order as when calling \link{coef.xgb.Booster}, with the intercept +corresponding to the first column). + +When there is more than one coefficient per feature (e.g. multi-class classification), +the result will be reshaped into a vector where coefficients are arranged first by features and +then by class (e.g. first 1 through N coefficients will be for the first class, then +coefficients N+1 through 2N for the second class, and so on). + +If the result has only one coefficient per feature in the data, then the resulting matrix +will have column names matching with the feature names, otherwise (when there's more than +one coefficient per feature) the names will be composed as 'column name' + ':' + 'class index' +(so e.g. column 'c1' for class '0' will be named 'c1:0'). + +With \code{xgb.train}, the output is either a dense or a sparse matrix. +With with \code{xgb.cv}, it is a list (one element per each fold) of such +matrices. + +Function \link{xgb.gblinear.history} function provides an easy way to retrieve the +outputs from this callback. } \examples{ #### Binary classification: @@ -52,7 +63,7 @@ param <- list(booster = "gblinear", objective = "reg:logistic", eval_metric = "a # rate does not break the convergence, but allows us to illustrate the typical pattern of # "stochastic explosion" behaviour of this lock-free algorithm at early boosting iterations. bst <- xgb.train(param, dtrain, list(tr=dtrain), nrounds = 200, eta = 1., - callbacks = list(cb.gblinear.history())) + callbacks = list(xgb.cb.gblinear.history())) # Extract the coefficients' path and plot them vs boosting iteration number: coef_path <- xgb.gblinear.history(bst) matplot(coef_path, type = 'l') @@ -61,7 +72,7 @@ matplot(coef_path, type = 'l') # Will try the classical componentwise boosting which selects a single best feature per round: bst <- xgb.train(param, dtrain, list(tr=dtrain), nrounds = 200, eta = 0.8, updater = 'coord_descent', feature_selector = 'thrifty', top_k = 1, - callbacks = list(cb.gblinear.history())) + callbacks = list(xgb.cb.gblinear.history())) matplot(xgb.gblinear.history(bst), type = 'l') # Componentwise boosting is known to have similar effect to Lasso regularization. # Try experimenting with various values of top_k, eta, nrounds, @@ -69,7 +80,7 @@ matplot(xgb.gblinear.history(bst), type = 'l') # For xgb.cv: bst <- xgb.cv(param, dtrain, nfold = 5, nrounds = 100, eta = 0.8, - callbacks = list(cb.gblinear.history())) + callbacks = list(xgb.cb.gblinear.history())) # coefficients in the CV fold #3 matplot(xgb.gblinear.history(bst)[[3]], type = 'l') @@ -82,7 +93,7 @@ param <- list(booster = "gblinear", objective = "multi:softprob", num_class = 3, # For the default linear updater 'shotgun' it sometimes is helpful # to use smaller eta to reduce instability bst <- xgb.train(param, dtrain, list(tr=dtrain), nrounds = 50, eta = 0.5, - callbacks = list(cb.gblinear.history())) + callbacks = list(xgb.cb.gblinear.history())) # Will plot the coefficient paths separately for each class: matplot(xgb.gblinear.history(bst, class_index = 0), type = 'l') matplot(xgb.gblinear.history(bst, class_index = 1), type = 'l') @@ -90,11 +101,11 @@ matplot(xgb.gblinear.history(bst, class_index = 2), type = 'l') # CV: bst <- xgb.cv(param, dtrain, nfold = 5, nrounds = 70, eta = 0.5, - callbacks = list(cb.gblinear.history(FALSE))) + callbacks = list(xgb.cb.gblinear.history(FALSE))) # 1st fold of 1st class matplot(xgb.gblinear.history(bst, class_index = 0)[[1]], type = 'l') } \seealso{ -\code{\link{callbacks}}, \code{\link{xgb.gblinear.history}}. +\link{xgb.gblinear.history}, \link{coef.xgb.Booster}. } diff --git a/R-package/man/xgb.cb.print.evaluation.Rd b/R-package/man/xgb.cb.print.evaluation.Rd new file mode 100644 index 000000000000..c4f2e6991278 --- /dev/null +++ b/R-package/man/xgb.cb.print.evaluation.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/callbacks.R +\name{xgb.cb.print.evaluation} +\alias{xgb.cb.print.evaluation} +\title{Callback for printing the result of evaluation} +\usage{ +xgb.cb.print.evaluation(period = 1, showsd = TRUE) +} +\arguments{ +\item{period}{results would be printed every number of periods} + +\item{showsd}{whether standard deviations should be printed (when available)} +} +\value{ +An \code{xgb.Callback} object, which can be passed to \link{xgb.train} or \link{xgb.cv}. +} +\description{ +The callback function prints the result of evaluation at every \code{period} iterations. +The initial and the last iteration's evaluations are always printed. + +Does not leave any attribute in the booster (see \link{xgb.cb.evaluation.log} for that). +} +\seealso{ +\link{xgb.Callback} +} diff --git a/R-package/man/cb.reset.parameters.Rd b/R-package/man/xgb.cb.reset.parameters.Rd similarity index 57% rename from R-package/man/cb.reset.parameters.Rd rename to R-package/man/xgb.cb.reset.parameters.Rd index ee0a5d1bde93..c7e8638178ac 100644 --- a/R-package/man/cb.reset.parameters.Rd +++ b/R-package/man/xgb.cb.reset.parameters.Rd @@ -1,10 +1,10 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/callbacks.R -\name{cb.reset.parameters} -\alias{cb.reset.parameters} -\title{Callback closure for resetting the booster's parameters at each iteration.} +\name{xgb.cb.reset.parameters} +\alias{xgb.cb.reset.parameters} +\title{Callback for resetting the booster's parameters at each iteration.} \usage{ -cb.reset.parameters(new_params) +xgb.cb.reset.parameters(new_params) } \arguments{ \item{new_params}{a list where each element corresponds to a parameter that needs to be reset. @@ -14,23 +14,16 @@ or a function of two parameters \code{learning_rates(iteration, nrounds)} which returns a new parameter value by using the current iteration number and the total number of boosting rounds.} } +\value{ +An \code{xgb.Callback} object, which can be passed to \link{xgb.train} or \link{xgb.cv}. +} \description{ -Callback closure for resetting the booster's parameters at each iteration. +Callback for resetting the booster's parameters at each iteration. } \details{ -This is a "pre-iteration" callback function used to reset booster's parameters -at the beginning of each iteration. - Note that when training is resumed from some previous model, and a function is used to reset a parameter value, the \code{nrounds} argument in this function would be the the number of boosting rounds in the current training. -Callback function expects the following values to be set in its calling frame: -\code{bst} or \code{bst_folds}, -\code{iteration}, -\code{begin_iteration}, -\code{end_iteration}. -} -\seealso{ -\code{\link{callbacks}} +Does not leave any attribute in the booster. } diff --git a/R-package/man/xgb.cb.save.model.Rd b/R-package/man/xgb.cb.save.model.Rd new file mode 100644 index 000000000000..8ddba2f1a587 --- /dev/null +++ b/R-package/man/xgb.cb.save.model.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/callbacks.R +\name{xgb.cb.save.model} +\alias{xgb.cb.save.model} +\title{Callback for saving a model file.} +\usage{ +xgb.cb.save.model(save_period = 0, save_name = "xgboost.ubj") +} +\arguments{ +\item{save_period}{Save the model to disk after every +\code{save_period} iterations; 0 means save the model at the end.} + +\item{save_name}{The name or path for the saved model file. +It can contain a \code{\link[base]{sprintf}} formatting specifier +to include the integer iteration number in the file name. +E.g., with \code{save_name} = 'xgboost_\%04d.model', +the file saved at iteration 50 would be named "xgboost_0050.model".} +} +\value{ +An \code{xgb.Callback} object, which can be passed to \link{xgb.train}, +but \bold{not} to \link{xgb.cv}. +} +\description{ +This callback function allows to save an xgb-model file, either periodically +after each \code{save_period}'s or at the end. + +Does not leave any attribute in the booster. +} diff --git a/R-package/man/xgb.create.features.Rd b/R-package/man/xgb.create.features.Rd index 68b5619970f9..995c27459a5e 100644 --- a/R-package/man/xgb.create.features.Rd +++ b/R-package/man/xgb.create.features.Rd @@ -82,7 +82,6 @@ new.dtrain <- xgb.DMatrix( new.dtest <- xgb.DMatrix( data = new.features.test, label = agaricus.test$label, nthread = 2 ) -watchlist <- list(train = new.dtrain) bst <- xgb.train(params = param, data = new.dtrain, nrounds = nrounds, nthread = 2) # Model accuracy with new features diff --git a/R-package/man/xgb.cv.Rd b/R-package/man/xgb.cv.Rd index 9f6103a52762..cede67570683 100644 --- a/R-package/man/xgb.cv.Rd +++ b/R-package/man/xgb.cv.Rd @@ -9,14 +9,12 @@ xgb.cv( data, nrounds, nfold, - label = NULL, - missing = NA, prediction = FALSE, showsd = TRUE, metrics = list(), obj = NULL, feval = NULL, - stratified = TRUE, + stratified = "auto", folds = NULL, train_folds = NULL, verbose = TRUE, @@ -44,22 +42,25 @@ is a shorter summary: } See \code{\link{xgb.train}} for further details. -See also demo/ for walkthrough example in R.} +See also demo/ for walkthrough example in R. -\item{data}{takes an \code{xgb.DMatrix}, \code{matrix}, or \code{dgCMatrix} as the input.} +Note that, while \code{params} accepts a \code{seed} entry and will use such parameter for model training if +supplied, this seed is not used for creation of train-test splits, which instead rely on R's own RNG +system - thus, for reproducible results, one needs to call the \code{set.seed} function beforehand.} -\item{nrounds}{the max number of iterations} +\item{data}{An \code{xgb.DMatrix} object, with corresponding fields like \code{label} or bounds as required +for model training by the objective. -\item{nfold}{the original dataset is randomly partitioned into \code{nfold} equal size subsamples.} +\if{html}{\out{
}}\preformatted{ Note that only the basic `xgb.DMatrix` class is supported - variants such as `xgb.QuantileDMatrix` + or `xgb.ExternalDMatrix` are not supported here. +}\if{html}{\out{
}}} -\item{label}{vector of response values. Should be provided only when data is an R-matrix.} +\item{nrounds}{the max number of iterations} -\item{missing}{is only used when input is a dense matrix. By default is set to NA, which means -that NA values should be considered as 'missing' by the algorithm. -Sometimes, 0 or other extreme value might be used to represent missing values.} +\item{nfold}{the original dataset is randomly partitioned into \code{nfold} equal size subsamples.} \item{prediction}{A logical value indicating whether to return the test fold predictions -from each CV model. This parameter engages the \code{\link{cb.cv.predict}} callback.} +from each CV model. This parameter engages the \code{\link{xgb.cb.cv.predict}} callback.} \item{showsd}{\code{boolean}, whether to show standard deviation of cross validation} @@ -84,34 +85,54 @@ gradient with given prediction and dtrain.} \code{list(metric='metric-name', value='metric-value')} with given prediction and dtrain.} -\item{stratified}{a \code{boolean} indicating whether sampling of folds should be stratified -by the values of outcome labels.} +\item{stratified}{A \code{boolean} indicating whether sampling of folds should be stratified +by the values of outcome labels. For real-valued labels in regression objectives, +stratification will be done by discretizing the labels into up to 5 buckets beforehand. + +\if{html}{\out{
}}\preformatted{ If passing "auto", will be set to `TRUE` if the objective in `params` is a classification + objective (from XGBoost's built-in objectives, doesn't apply to custom ones), and to + `FALSE` otherwise. + + This parameter is ignored when `data` has a `group` field - in such case, the splitting + will be based on whole groups (note that this might make the folds have different sizes). + + Value `TRUE` here is \\bold\{not\} supported for custom objectives. +}\if{html}{\out{
}}} \item{folds}{\code{list} provides a possibility to use a list of pre-defined CV folds (each element must be a vector of test fold's indices). When folds are supplied, -the \code{nfold} and \code{stratified} parameters are ignored.} +the \code{nfold} and \code{stratified} parameters are ignored. + +\if{html}{\out{
}}\preformatted{ If `data` has a `group` field and the objective requires this field, each fold (list element) + must additionally have two attributes (retrievable through \link{attributes}) named `group_test` + and `group_train`, which should hold the `group` to assign through \link{setinfo.xgb.DMatrix} to + the resulting DMatrices. +}\if{html}{\out{
}}} \item{train_folds}{\code{list} list specifying which indicies to use for training. If \code{NULL} -(the default) all indices not specified in \code{folds} will be used for training.} +(the default) all indices not specified in \code{folds} will be used for training. + +\if{html}{\out{
}}\preformatted{ This is not supported when `data` has `group` field. +}\if{html}{\out{
}}} \item{verbose}{\code{boolean}, print the statistics during the process} \item{print_every_n}{Print each n-th iteration evaluation messages when \code{verbose>0}. Default is 1 which means all messages are printed. This parameter is passed to the -\code{\link{cb.print.evaluation}} callback.} +\code{\link{xgb.cb.print.evaluation}} callback.} \item{early_stopping_rounds}{If \code{NULL}, the early stopping function is not triggered. If set to an integer \code{k}, training with a validation set will stop if the performance doesn't improve for \code{k} rounds. -Setting this parameter engages the \code{\link{cb.early.stop}} callback.} +Setting this parameter engages the \code{\link{xgb.cb.early.stop}} callback.} \item{maximize}{If \code{feval} and \code{early_stopping_rounds} are set, then this parameter must be set as well. When it is \code{TRUE}, it means the larger the evaluation score the better. -This parameter is passed to the \code{\link{cb.early.stop}} callback.} +This parameter is passed to the \code{\link{xgb.cb.early.stop}} callback.} \item{callbacks}{a list of callback functions to perform various task during boosting. -See \code{\link{callbacks}}. Some of the callbacks are automatically created depending on the +See \code{\link{xgb.Callback}}. Some of the callbacks are automatically created depending on the parameters' values. User can provide either existing or their own callback methods in order to customize the training process.} @@ -122,27 +143,27 @@ An object of class \code{xgb.cv.synchronous} with the following elements: \itemize{ \item \code{call} a function call. \item \code{params} parameters that were passed to the xgboost library. Note that it does not -capture parameters changed by the \code{\link{cb.reset.parameters}} callback. -\item \code{callbacks} callback functions that were either automatically assigned or -explicitly passed. +capture parameters changed by the \code{\link{xgb.cb.reset.parameters}} callback. \item \code{evaluation_log} evaluation history stored as a \code{data.table} with the first column corresponding to iteration number and the rest corresponding to the CV-based evaluation means and standard deviations for the training and test CV-sets. -It is created by the \code{\link{cb.evaluation.log}} callback. +It is created by the \code{\link{xgb.cb.evaluation.log}} callback. \item \code{niter} number of boosting iterations. \item \code{nfeatures} number of features in training data. \item \code{folds} the list of CV folds' indices - either those passed through the \code{folds} parameter or randomly generated. \item \code{best_iteration} iteration number with the best evaluation metric value (only available with early stopping). -\item \code{pred} CV prediction values available when \code{prediction} is set. -It is either vector or matrix (see \code{\link{cb.cv.predict}}). -\item \code{models} a list of the CV folds' models. It is only available with the explicit -setting of the \code{cb.cv.predict(save_models = TRUE)} callback. } + +Plus other potential elements that are the result of callbacks, such as a list \code{cv_predict} with +a sub-element \code{pred} when passing \code{prediction = TRUE}, which is added by the \link{xgb.cb.cv.predict} +callback (note that one can also pass it manually under \code{callbacks} with different settings, +such as saving also the models created during cross validation); or a list \code{early_stop} which +will contain elements such as \code{best_iteration} when using the early stopping callback (\link{xgb.cb.early.stop}). } \description{ -The cross validation function of xgboost +The cross validation function of xgboost. } \details{ The original sample is randomly partitioned into \code{nfold} equal size subsamples. diff --git a/R-package/man/xgb.gblinear.history.Rd b/R-package/man/xgb.gblinear.history.Rd index 103be16f11a9..25aef7163e40 100644 --- a/R-package/man/xgb.gblinear.history.Rd +++ b/R-package/man/xgb.gblinear.history.Rd @@ -8,7 +8,7 @@ xgb.gblinear.history(model, class_index = NULL) } \arguments{ \item{model}{either an \code{xgb.Booster} or a result of \code{xgb.cv()}, trained -using the \code{cb.gblinear.history()} callback, but \bold{not} a booster +using the \link{xgb.cb.gblinear.history} callback, but \bold{not} a booster loaded from \link{xgb.load} or \link{xgb.load.raw}.} \item{class_index}{zero-based class index to extract the coefficients for only that @@ -16,23 +16,31 @@ specific class in a multinomial multiclass model. When it is NULL, all the coefficients are returned. Has no effect in non-multiclass models.} } \value{ -For an \code{xgb.train} result, a matrix (either dense or sparse) with the columns -corresponding to iteration's coefficients (in the order as \code{xgb.dump()} would -return) and the rows corresponding to boosting iterations. +For an \link{xgb.train} result, a matrix (either dense or sparse) with the columns +corresponding to iteration's coefficients and the rows corresponding to boosting iterations. -For an \code{xgb.cv} result, a list of such matrices is returned with the elements +For an \link{xgb.cv} result, a list of such matrices is returned with the elements corresponding to CV folds. + +When there is more than one coefficient per feature (e.g. multi-class classification) +and \code{class_index} is not provided, +the result will be reshaped into a vector where coefficients are arranged first by features and +then by class (e.g. first 1 through N coefficients will be for the first class, then +coefficients N+1 through 2N for the second class, and so on). } \description{ A helper function to extract the matrix of linear coefficients' history -from a gblinear model created while using the \code{cb.gblinear.history()} -callback. +from a gblinear model created while using the \link{xgb.cb.gblinear.history} +callback (which must be added manually as by default it's not used). } \details{ Note that this is an R-specific function that relies on R attributes that are not saved when using xgboost's own serialization functions like \link{xgb.load} or \link{xgb.load.raw}. -In order for a serialized model to be accepted by tgis function, one must use R +In order for a serialized model to be accepted by this function, one must use R serializers such as \link{saveRDS}. } +\seealso{ +\link{xgb.cb.gblinear.history}, \link{coef.xgb.Booster}. +} diff --git a/R-package/man/xgb.load.Rd b/R-package/man/xgb.load.Rd index 1fbe0055ed9d..e18a900e3f13 100644 --- a/R-package/man/xgb.load.Rd +++ b/R-package/man/xgb.load.Rd @@ -17,7 +17,7 @@ Load xgboost model from the binary model file. } \details{ The input file is expected to contain a model saved in an xgboost model format -using either \code{\link{xgb.save}} or \code{\link{cb.save.model}} in R, or using some +using either \code{\link{xgb.save}} or \code{\link{xgb.cb.save.model}} in R, or using some appropriate methods from other xgboost interfaces. E.g., a model trained in Python and saved from there in xgboost format, could be loaded from R. diff --git a/R-package/man/xgb.slice.DMatrix.Rd b/R-package/man/xgb.slice.DMatrix.Rd index c9695996b66f..c4f7765943bb 100644 --- a/R-package/man/xgb.slice.DMatrix.Rd +++ b/R-package/man/xgb.slice.DMatrix.Rd @@ -6,14 +6,18 @@ \title{Get a new DMatrix containing the specified rows of original xgb.DMatrix object} \usage{ -xgb.slice.DMatrix(object, idxset) +xgb.slice.DMatrix(object, idxset, allow_groups = FALSE) \method{[}{xgb.DMatrix}(object, idxset, colset = NULL) } \arguments{ -\item{object}{Object of class "xgb.DMatrix"} +\item{object}{Object of class "xgb.DMatrix".} -\item{idxset}{a integer vector of indices of rows needed} +\item{idxset}{An integer vector of indices of rows needed (base-1 indexing).} + +\item{allow_groups}{Whether to allow slicing an \code{xgb.DMatrix} with \code{group} (or +equivalently \code{qid}) field. Note that in such case, the result will not have +the groups anymore - they need to be set manually through \code{setinfo}.} \item{colset}{currently not used (columns subsetting is not available)} } diff --git a/R-package/man/xgb.train.Rd b/R-package/man/xgb.train.Rd index 21c5fe7eebe4..21c8dbe16413 100644 --- a/R-package/man/xgb.train.Rd +++ b/R-package/man/xgb.train.Rd @@ -9,7 +9,7 @@ xgb.train( params = list(), data, nrounds, - watchlist = list(), + evals = list(), obj = NULL, feval = NULL, verbose = 1, @@ -158,13 +158,13 @@ List is provided in detail section.} \item{nrounds}{max number of boosting iterations.} -\item{watchlist}{named list of xgb.DMatrix datasets to use for evaluating model performance. +\item{evals}{Named list of \code{xgb.DMatrix} datasets to use for evaluating model performance. Metrics specified in either \code{eval_metric} or \code{feval} will be computed for each of these datasets during each boosting iteration, and stored in the end as a field named \code{evaluation_log} in the resulting object. When either \code{verbose>=1} or -\code{\link{cb.print.evaluation}} callback is engaged, the performance results are continuously +\code{\link{xgb.cb.print.evaluation}} callback is engaged, the performance results are continuously printed out during the training. -E.g., specifying \code{watchlist=list(validation1=mat1, validation2=mat2)} allows to track +E.g., specifying \code{evals=list(validation1=mat1, validation2=mat2)} allows to track the performance of each round's model on mat1 and mat2.} \item{obj}{customized objective function. Returns gradient and second order @@ -177,24 +177,24 @@ prediction and dtrain.} \item{verbose}{If 0, xgboost will stay silent. If 1, it will print information about performance. If 2, some additional information will be printed out. Note that setting \code{verbose > 0} automatically engages the -\code{cb.print.evaluation(period=1)} callback function.} +\code{xgb.cb.print.evaluation(period=1)} callback function.} \item{print_every_n}{Print each n-th iteration evaluation messages when \code{verbose>0}. Default is 1 which means all messages are printed. This parameter is passed to the -\code{\link{cb.print.evaluation}} callback.} +\code{\link{xgb.cb.print.evaluation}} callback.} \item{early_stopping_rounds}{If \code{NULL}, the early stopping function is not triggered. If set to an integer \code{k}, training with a validation set will stop if the performance doesn't improve for \code{k} rounds. -Setting this parameter engages the \code{\link{cb.early.stop}} callback.} +Setting this parameter engages the \code{\link{xgb.cb.early.stop}} callback.} \item{maximize}{If \code{feval} and \code{early_stopping_rounds} are set, then this parameter must be set as well. When it is \code{TRUE}, it means the larger the evaluation score the better. -This parameter is passed to the \code{\link{cb.early.stop}} callback.} +This parameter is passed to the \code{\link{xgb.cb.early.stop}} callback.} \item{save_period}{when it is non-NULL, model is saved to disk after every \code{save_period} rounds, -0 means save at the end. The saving is handled by the \code{\link{cb.save.model}} callback.} +0 means save at the end. The saving is handled by the \code{\link{xgb.cb.save.model}} callback.} \item{save_name}{the name or path for periodically saved model file.} @@ -203,12 +203,13 @@ Could be either an object of class \code{xgb.Booster}, or its raw data, or the n file with a previously saved model.} \item{callbacks}{a list of callback functions to perform various task during boosting. -See \code{\link{callbacks}}. Some of the callbacks are automatically created depending on the +See \code{\link{xgb.Callback}}. Some of the callbacks are automatically created depending on the parameters' values. User can provide either existing or their own callback methods in order to customize the training process. -\if{html}{\out{
}}\preformatted{ Note that some callbacks might try to set an evaluation log - be aware that these evaluation logs - are kept as R attributes, and thus do not get saved when using non-R serializaters like +\if{html}{\out{
}}\preformatted{ Note that some callbacks might try to leave attributes in the resulting model object, + such as an evaluation log (a `data.table` object) - be aware that these objects are kept + as R attributes, and thus do not get saved when using XGBoost's own serializaters like \link{xgb.save} (but are kept when using R serializers like \link{saveRDS}). }\if{html}{\out{
}}} @@ -233,7 +234,7 @@ The \code{xgboost} function is a simpler wrapper for \code{xgb.train}. \details{ These are the training functions for \code{xgboost}. -The \code{xgb.train} interface supports advanced features such as \code{watchlist}, +The \code{xgb.train} interface supports advanced features such as \code{evals}, customized objective and evaluation metric functions, therefore it is more flexible than the \code{xgboost} interface. @@ -269,18 +270,19 @@ Different threshold (e.g., 0.) could be specified as "error@0." The following callbacks are automatically created when certain parameters are set: \itemize{ -\item \code{cb.print.evaluation} is turned on when \code{verbose > 0}; +\item \code{xgb.cb.print.evaluation} is turned on when \code{verbose > 0}; and the \code{print_every_n} parameter is passed to it. -\item \code{cb.evaluation.log} is on when \code{watchlist} is present. -\item \code{cb.early.stop}: when \code{early_stopping_rounds} is set. -\item \code{cb.save.model}: when \code{save_period > 0} is set. +\item \code{xgb.cb.evaluation.log} is on when \code{evals} is present. +\item \code{xgb.cb.early.stop}: when \code{early_stopping_rounds} is set. +\item \code{xgb.cb.save.model}: when \code{save_period > 0} is set. } Note that objects of type \code{xgb.Booster} as returned by this function behave a bit differently from typical R objects (it's an 'altrep' list class), and it makes a separation between internal booster attributes (restricted to jsonifyable data), accessed through \link{xgb.attr} and shared between interfaces through serialization functions like \link{xgb.save}; and -R-specific attributes, accessed through \link{attributes} and \link{attr}, which are otherwise +R-specific attributes (typically the result from a callback), accessed through \link{attributes} +and \link{attr}, which are otherwise only used in the R interface, only kept when using R's serializers like \link{saveRDS}, and not anyhow used by functions like \link{predict.xgb.Booster}. @@ -305,12 +307,12 @@ dtrain <- with( dtest <- with( agaricus.test, xgb.DMatrix(data, label = label, nthread = nthread) ) -watchlist <- list(train = dtrain, eval = dtest) +evals <- list(train = dtrain, eval = dtest) ## A simple xgb.train example: param <- list(max_depth = 2, eta = 1, nthread = nthread, objective = "binary:logistic", eval_metric = "auc") -bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0) +bst <- xgb.train(param, dtrain, nrounds = 2, evals = evals, verbose = 0) ## An xgb.train example where custom objective and evaluation metric are ## used: @@ -331,15 +333,15 @@ evalerror <- function(preds, dtrain) { # as 'objective' and 'eval_metric' parameters in the params list: param <- list(max_depth = 2, eta = 1, nthread = nthread, objective = logregobj, eval_metric = evalerror) -bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0) +bst <- xgb.train(param, dtrain, nrounds = 2, evals = evals, verbose = 0) # or through the ... arguments: param <- list(max_depth = 2, eta = 1, nthread = nthread) -bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0, +bst <- xgb.train(param, dtrain, nrounds = 2, evals = evals, verbose = 0, objective = logregobj, eval_metric = evalerror) # or as dedicated 'obj' and 'feval' parameters of xgb.train: -bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, +bst <- xgb.train(param, dtrain, nrounds = 2, evals = evals, obj = logregobj, feval = evalerror) @@ -347,11 +349,11 @@ bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, param <- list(max_depth = 2, eta = 1, nthread = nthread, objective = "binary:logistic", eval_metric = "auc") my_etas <- list(eta = c(0.5, 0.1)) -bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0, - callbacks = list(cb.reset.parameters(my_etas))) +bst <- xgb.train(param, dtrain, nrounds = 2, evals = evals, verbose = 0, + callbacks = list(xgb.cb.reset.parameters(my_etas))) ## Early stopping: -bst <- xgb.train(param, dtrain, nrounds = 25, watchlist, +bst <- xgb.train(param, dtrain, nrounds = 25, evals = evals, early_stopping_rounds = 3) ## An 'xgboost' interface example: @@ -366,7 +368,7 @@ Tianqi Chen and Carlos Guestrin, "XGBoost: A Scalable Tree Boosting System", 22nd SIGKDD Conference on Knowledge Discovery and Data Mining, 2016, \url{https://arxiv.org/abs/1603.02754} } \seealso{ -\code{\link{callbacks}}, +\code{\link{xgb.Callback}}, \code{\link{predict.xgb.Booster}}, \code{\link{xgb.cv}} } diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index 0f4b3ac6f6a7..93cfb8e5b4c1 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -99,15 +99,14 @@ OBJECTS= \ $(PKGROOT)/src/context.o \ $(PKGROOT)/src/logging.o \ $(PKGROOT)/src/global_config.o \ + $(PKGROOT)/src/collective/result.o \ $(PKGROOT)/src/collective/allgather.o \ $(PKGROOT)/src/collective/allreduce.o \ $(PKGROOT)/src/collective/broadcast.o \ $(PKGROOT)/src/collective/comm.o \ + $(PKGROOT)/src/collective/comm_group.o \ $(PKGROOT)/src/collective/coll.o \ - $(PKGROOT)/src/collective/communicator-inl.o \ $(PKGROOT)/src/collective/tracker.o \ - $(PKGROOT)/src/collective/communicator.o \ - $(PKGROOT)/src/collective/in_memory_communicator.o \ $(PKGROOT)/src/collective/in_memory_handler.o \ $(PKGROOT)/src/collective/loop.o \ $(PKGROOT)/src/collective/socket.o \ @@ -132,7 +131,4 @@ OBJECTS= \ $(PKGROOT)/src/common/version.o \ $(PKGROOT)/src/c_api/c_api.o \ $(PKGROOT)/src/c_api/c_api_error.o \ - $(PKGROOT)/amalgamation/dmlc-minimum0.o \ - $(PKGROOT)/rabit/src/engine.o \ - $(PKGROOT)/rabit/src/rabit_c_api.o \ - $(PKGROOT)/rabit/src/allreduce_base.o + $(PKGROOT)/amalgamation/dmlc-minimum0.o diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index 0c2084de940c..f160930e8a4a 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -99,15 +99,14 @@ OBJECTS= \ $(PKGROOT)/src/context.o \ $(PKGROOT)/src/logging.o \ $(PKGROOT)/src/global_config.o \ + $(PKGROOT)/src/collective/result.o \ $(PKGROOT)/src/collective/allgather.o \ $(PKGROOT)/src/collective/allreduce.o \ $(PKGROOT)/src/collective/broadcast.o \ $(PKGROOT)/src/collective/comm.o \ + $(PKGROOT)/src/collective/comm_group.o \ $(PKGROOT)/src/collective/coll.o \ - $(PKGROOT)/src/collective/communicator-inl.o \ $(PKGROOT)/src/collective/tracker.o \ - $(PKGROOT)/src/collective/communicator.o \ - $(PKGROOT)/src/collective/in_memory_communicator.o \ $(PKGROOT)/src/collective/in_memory_handler.o \ $(PKGROOT)/src/collective/loop.o \ $(PKGROOT)/src/collective/socket.o \ @@ -132,7 +131,4 @@ OBJECTS= \ $(PKGROOT)/src/common/version.o \ $(PKGROOT)/src/c_api/c_api.o \ $(PKGROOT)/src/c_api/c_api_error.o \ - $(PKGROOT)/amalgamation/dmlc-minimum0.o \ - $(PKGROOT)/rabit/src/engine.o \ - $(PKGROOT)/rabit/src/rabit_c_api.o \ - $(PKGROOT)/rabit/src/allreduce_base.o + $(PKGROOT)/amalgamation/dmlc-minimum0.o diff --git a/R-package/src/init.c b/R-package/src/init.c index f2635742ebd7..5db3218b4e1b 100644 --- a/R-package/src/init.c +++ b/R-package/src/init.c @@ -71,11 +71,12 @@ extern SEXP XGDMatrixGetDataAsCSR_R(SEXP); extern SEXP XGDMatrixSaveBinary_R(SEXP, SEXP, SEXP); extern SEXP XGDMatrixSetInfo_R(SEXP, SEXP, SEXP); extern SEXP XGDMatrixSetStrFeatureInfo_R(SEXP, SEXP, SEXP); -extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP); +extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP, SEXP); extern SEXP XGBSetGlobalConfig_R(SEXP); extern SEXP XGBGetGlobalConfig_R(void); extern SEXP XGBoosterFeatureScore_R(SEXP, SEXP); extern SEXP XGBoosterSlice_R(SEXP, SEXP, SEXP, SEXP); +extern SEXP XGBoosterSliceAndReplace_R(SEXP, SEXP, SEXP, SEXP); static const R_CallMethodDef CallEntries[] = { {"XGDuplicate_R", (DL_FUNC) &XGDuplicate_R, 1}, @@ -133,11 +134,12 @@ static const R_CallMethodDef CallEntries[] = { {"XGDMatrixSaveBinary_R", (DL_FUNC) &XGDMatrixSaveBinary_R, 3}, {"XGDMatrixSetInfo_R", (DL_FUNC) &XGDMatrixSetInfo_R, 3}, {"XGDMatrixSetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixSetStrFeatureInfo_R, 3}, - {"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 2}, + {"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 3}, {"XGBSetGlobalConfig_R", (DL_FUNC) &XGBSetGlobalConfig_R, 1}, {"XGBGetGlobalConfig_R", (DL_FUNC) &XGBGetGlobalConfig_R, 0}, {"XGBoosterFeatureScore_R", (DL_FUNC) &XGBoosterFeatureScore_R, 2}, {"XGBoosterSlice_R", (DL_FUNC) &XGBoosterSlice_R, 4}, + {"XGBoosterSliceAndReplace_R", (DL_FUNC) &XGBoosterSliceAndReplace_R, 4}, {NULL, NULL, 0} }; diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index 5baf8d41282e..cdb9ba65c3ef 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -512,7 +512,7 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP return ret; } -XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) { +XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset, SEXP allow_groups) { SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue)); R_API_BEGIN(); R_xlen_t len = Rf_xlength(idxset); @@ -531,7 +531,7 @@ XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) { res_code = XGDMatrixSliceDMatrixEx(R_ExternalPtrAddr(handle), BeginPtr(idxvec), len, &res, - 0); + Rf_asLogical(allow_groups)); } CHECK_CALL(res_code); R_SetExternalPtrAddr(ret, res); @@ -1674,3 +1674,18 @@ XGB_DLL SEXP XGBoosterSlice_R(SEXP handle, SEXP begin_layer, SEXP end_layer, SEX Rf_unprotect(1); return out; } + +XGB_DLL SEXP XGBoosterSliceAndReplace_R(SEXP handle, SEXP begin_layer, SEXP end_layer, SEXP step) { + R_API_BEGIN(); + BoosterHandle old_handle = R_ExternalPtrAddr(handle); + BoosterHandle new_handle = nullptr; + CHECK_CALL(XGBoosterSlice(old_handle, + Rf_asInteger(begin_layer), + Rf_asInteger(end_layer), + Rf_asInteger(step), + &new_handle)); + R_SetExternalPtrAddr(handle, new_handle); + CHECK_CALL(XGBoosterFree(old_handle)); + R_API_END(); + return R_NilValue; +} diff --git a/R-package/src/xgboost_R.h b/R-package/src/xgboost_R.h index 70fd885e7f12..62be5022a3d2 100644 --- a/R-package/src/xgboost_R.h +++ b/R-package/src/xgboost_R.h @@ -112,9 +112,10 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP * \brief create a new dmatrix from sliced content of existing matrix * \param handle instance of data matrix to be sliced * \param idxset index set + * \param allow_groups Whether to allow slicing the DMatrix if it has a 'group' field * \return a sliced new matrix */ -XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset); +XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset, SEXP allow_groups); /*! * \brief load a data matrix into binary file @@ -535,4 +536,14 @@ XGB_DLL SEXP XGBoosterFeatureScore_R(SEXP handle, SEXP json_config); */ XGB_DLL SEXP XGBoosterSlice_R(SEXP handle, SEXP begin_layer, SEXP end_layer, SEXP step); +/*! + * \brief Slice a fitted booster model (by rounds), and replace its handle with the result + * \param handle handle to the fitted booster + * \param begin_layer start of the slice + * \param end_later end of the slice; end_layer=0 is equivalent to end_layer=num_boost_round + * \param step step size of the slice + * \return NULL + */ +XGB_DLL SEXP XGBoosterSliceAndReplace_R(SEXP handle, SEXP begin_layer, SEXP end_layer, SEXP step); + #endif // XGBOOST_WRAPPER_R_H_ // NOLINT(*) diff --git a/R-package/tests/testthat.R b/R-package/tests/testthat.R index 7cf711292c48..bad6c1df3915 100644 --- a/R-package/tests/testthat.R +++ b/R-package/tests/testthat.R @@ -1,5 +1,6 @@ library(testthat) library(xgboost) +library(Matrix) test_check("xgboost", reporter = ProgressReporter) RhpcBLASctl::omp_set_num_threads(1) diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index 5438c8bb2235..bbb8fb323478 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -20,7 +20,7 @@ test_that("train and predict binary classification", { data = xgb.DMatrix(train$data, label = train$label), max_depth = 2, eta = 1, nthread = n_threads, nrounds = nrounds, objective = "binary:logistic", eval_metric = "error", - watchlist = list(train = xgb.DMatrix(train$data, label = train$label)) + evals = list(train = xgb.DMatrix(train$data, label = train$label)) ), "train-error" ) @@ -152,7 +152,7 @@ test_that("train and predict softprob", { data = xgb.DMatrix(as.matrix(iris[, -5]), label = lb), max_depth = 3, eta = 0.5, nthread = n_threads, nrounds = 5, objective = "multi:softprob", num_class = 3, eval_metric = "merror", - watchlist = list(train = xgb.DMatrix(as.matrix(iris[, -5]), label = lb)) + evals = list(train = xgb.DMatrix(as.matrix(iris[, -5]), label = lb)) ), "train-merror" ) @@ -203,7 +203,7 @@ test_that("train and predict softmax", { data = xgb.DMatrix(as.matrix(iris[, -5]), label = lb), max_depth = 3, eta = 0.5, nthread = n_threads, nrounds = 5, objective = "multi:softmax", num_class = 3, eval_metric = "merror", - watchlist = list(train = xgb.DMatrix(as.matrix(iris[, -5]), label = lb)) + evals = list(train = xgb.DMatrix(as.matrix(iris[, -5]), label = lb)) ), "train-merror" ) @@ -226,7 +226,7 @@ test_that("train and predict RF", { nthread = n_threads, nrounds = 1, objective = "binary:logistic", eval_metric = "error", num_parallel_tree = 20, subsample = 0.6, colsample_bytree = 0.1, - watchlist = list(train = xgb.DMatrix(train$data, label = lb)) + evals = list(train = xgb.DMatrix(train$data, label = lb)) ) expect_equal(xgb.get.num.boosted.rounds(bst), 1) @@ -250,7 +250,7 @@ test_that("train and predict RF with softprob", { objective = "multi:softprob", eval_metric = "merror", num_class = 3, verbose = 0, num_parallel_tree = 4, subsample = 0.5, colsample_bytree = 0.5, - watchlist = list(train = xgb.DMatrix(as.matrix(iris[, -5]), label = lb)) + evals = list(train = xgb.DMatrix(as.matrix(iris[, -5]), label = lb)) ) expect_equal(xgb.get.num.boosted.rounds(bst), 15) # predict for all iterations: @@ -271,7 +271,7 @@ test_that("use of multiple eval metrics works", { data = xgb.DMatrix(train$data, label = train$label), max_depth = 2, eta = 1, nthread = n_threads, nrounds = 2, objective = "binary:logistic", eval_metric = "error", eval_metric = "auc", eval_metric = "logloss", - watchlist = list(train = xgb.DMatrix(train$data, label = train$label)) + evals = list(train = xgb.DMatrix(train$data, label = train$label)) ), "train-error.*train-auc.*train-logloss" ) @@ -283,7 +283,7 @@ test_that("use of multiple eval metrics works", { data = xgb.DMatrix(train$data, label = train$label), max_depth = 2, eta = 1, nthread = n_threads, nrounds = 2, objective = "binary:logistic", eval_metric = list("error", "auc", "logloss"), - watchlist = list(train = xgb.DMatrix(train$data, label = train$label)) + evals = list(train = xgb.DMatrix(train$data, label = train$label)) ), "train-error.*train-auc.*train-logloss" ) @@ -295,19 +295,19 @@ test_that("use of multiple eval metrics works", { test_that("training continuation works", { dtrain <- xgb.DMatrix(train$data, label = train$label, nthread = n_threads) - watchlist <- list(train = dtrain) + evals <- list(train = dtrain) param <- list( objective = "binary:logistic", max_depth = 2, eta = 1, nthread = n_threads ) # for the reference, use 4 iterations at once: set.seed(11) - bst <- xgb.train(param, dtrain, nrounds = 4, watchlist, verbose = 0) + bst <- xgb.train(param, dtrain, nrounds = 4, evals = evals, verbose = 0) # first two iterations: set.seed(11) - bst1 <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0) + bst1 <- xgb.train(param, dtrain, nrounds = 2, evals = evals, verbose = 0) # continue for two more: - bst2 <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0, xgb_model = bst1) + bst2 <- xgb.train(param, dtrain, nrounds = 2, evals = evals, verbose = 0, xgb_model = bst1) if (!windows_flag && !solaris_flag) { expect_equal(xgb.save.raw(bst), xgb.save.raw(bst2)) } @@ -315,7 +315,7 @@ test_that("training continuation works", { expect_equal(dim(attributes(bst2)$evaluation_log), c(4, 2)) expect_equal(attributes(bst2)$evaluation_log, attributes(bst)$evaluation_log) # test continuing from raw model data - bst2 <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0, xgb_model = xgb.save.raw(bst1)) + bst2 <- xgb.train(param, dtrain, nrounds = 2, evals = evals, verbose = 0, xgb_model = xgb.save.raw(bst1)) if (!windows_flag && !solaris_flag) { expect_equal(xgb.save.raw(bst), xgb.save.raw(bst2)) } @@ -323,7 +323,7 @@ test_that("training continuation works", { # test continuing from a model in file fname <- file.path(tempdir(), "xgboost.json") xgb.save(bst1, fname) - bst2 <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0, xgb_model = fname) + bst2 <- xgb.train(param, dtrain, nrounds = 2, evals = evals, verbose = 0, xgb_model = fname) if (!windows_flag && !solaris_flag) { expect_equal(xgb.save.raw(bst), xgb.save.raw(bst2)) } @@ -334,7 +334,7 @@ test_that("xgb.cv works", { set.seed(11) expect_output( cv <- xgb.cv( - data = train$data, label = train$label, max_depth = 2, nfold = 5, + data = xgb.DMatrix(train$data, label = train$label), max_depth = 2, nfold = 5, eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic", eval_metric = "error", verbose = TRUE ), @@ -348,7 +348,6 @@ test_that("xgb.cv works", { expect_false(is.null(cv$folds) && is.list(cv$folds)) expect_length(cv$folds, 5) expect_false(is.null(cv$params) && is.list(cv$params)) - expect_false(is.null(cv$callbacks)) expect_false(is.null(cv$call)) }) @@ -358,13 +357,13 @@ test_that("xgb.cv works with stratified folds", { cv <- xgb.cv( data = dtrain, max_depth = 2, nfold = 5, eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic", - verbose = TRUE, stratified = FALSE + verbose = FALSE, stratified = FALSE ) set.seed(314159) cv2 <- xgb.cv( data = dtrain, max_depth = 2, nfold = 5, eta = 1., nthread = n_threads, nrounds = 2, objective = "binary:logistic", - verbose = TRUE, stratified = TRUE + verbose = FALSE, stratified = TRUE ) # Stratified folds should result in a different evaluation logs expect_true(all(cv$evaluation_log[, test_logloss_mean] != cv2$evaluation_log[, test_logloss_mean])) @@ -418,7 +417,7 @@ test_that("max_delta_step works", { dtrain <- xgb.DMatrix( agaricus.train$data, label = agaricus.train$label, nthread = n_threads ) - watchlist <- list(train = dtrain) + evals <- list(train = dtrain) param <- list( objective = "binary:logistic", eval_metric = "logloss", max_depth = 2, nthread = n_threads, @@ -426,9 +425,9 @@ test_that("max_delta_step works", { ) nrounds <- 5 # model with no restriction on max_delta_step - bst1 <- xgb.train(param, dtrain, nrounds, watchlist, verbose = 1) + bst1 <- xgb.train(param, dtrain, nrounds, evals = evals, verbose = 1) # model with restricted max_delta_step - bst2 <- xgb.train(param, dtrain, nrounds, watchlist, verbose = 1, max_delta_step = 1) + bst2 <- xgb.train(param, dtrain, nrounds, evals = evals, verbose = 1, max_delta_step = 1) # the no-restriction model is expected to have consistently lower loss during the initial iterations expect_true(all(attributes(bst1)$evaluation_log$train_logloss < attributes(bst2)$evaluation_log$train_logloss)) expect_lt(mean(attributes(bst1)$evaluation_log$train_logloss) / mean(attributes(bst2)$evaluation_log$train_logloss), 0.8) @@ -445,7 +444,7 @@ test_that("colsample_bytree works", { colnames(test_x) <- paste0("Feature_", sprintf("%03d", 1:100)) dtrain <- xgb.DMatrix(train_x, label = train_y, nthread = n_threads) dtest <- xgb.DMatrix(test_x, label = test_y, nthread = n_threads) - watchlist <- list(train = dtrain, eval = dtest) + evals <- list(train = dtrain, eval = dtest) ## Use colsample_bytree = 0.01, so that roughly one out of 100 features is chosen for ## each tree param <- list( @@ -454,7 +453,7 @@ test_that("colsample_bytree works", { eval_metric = "auc" ) set.seed(2) - bst <- xgb.train(param, dtrain, nrounds = 100, watchlist, verbose = 0) + bst <- xgb.train(param, dtrain, nrounds = 100, evals = evals, verbose = 0) xgb.importance(model = bst) # If colsample_bytree works properly, a variety of features should be used # in the 100 trees @@ -886,3 +885,57 @@ test_that("Seed in params override PRNG from R", { ) ) }) + +test_that("xgb.cv works for AFT", { + X <- matrix(c(1, -1, -1, 1, 0, 1, 1, 0), nrow = 4, byrow = TRUE) # 4x2 matrix + dtrain <- xgb.DMatrix(X, nthread = n_threads) + + params <- list(objective = 'survival:aft', learning_rate = 0.2, max_depth = 2L) + + # data must have bounds + expect_error( + xgb.cv( + params = params, + data = dtrain, + nround = 5L, + nfold = 4L, + nthread = n_threads + ) + ) + + setinfo(dtrain, 'label_lower_bound', c(2, 3, 0, 4)) + setinfo(dtrain, 'label_upper_bound', c(2, Inf, 4, 5)) + + # automatic stratified splitting is turned off + expect_warning( + xgb.cv( + params = params, data = dtrain, nround = 5L, nfold = 4L, + nthread = n_threads, stratified = TRUE, verbose = FALSE + ) + ) + + # this works without any issue + expect_no_warning( + xgb.cv(params = params, data = dtrain, nround = 5L, nfold = 4L, verbose = FALSE) + ) +}) + +test_that("xgb.cv works for ranking", { + data(iris) + x <- iris[, -(4:5)] + y <- as.integer(iris$Petal.Width) + group <- rep(50, 3) + dm <- xgb.DMatrix(x, label = y, group = group) + res <- xgb.cv( + data = dm, + params = list( + objective = "rank:pairwise", + max_depth = 3 + ), + nrounds = 3, + nfold = 2, + verbose = FALSE, + stratified = FALSE + ) + expect_equal(length(res$folds), 2L) +}) diff --git a/R-package/tests/testthat/test_callbacks.R b/R-package/tests/testthat/test_callbacks.R index c60d0c246f81..bf95a170dcfc 100644 --- a/R-package/tests/testthat/test_callbacks.R +++ b/R-package/tests/testthat/test_callbacks.R @@ -19,7 +19,7 @@ ltrain <- add.noise(train$label, 0.2) ltest <- add.noise(test$label, 0.2) dtrain <- xgb.DMatrix(train$data, label = ltrain, nthread = n_threads) dtest <- xgb.DMatrix(test$data, label = ltest, nthread = n_threads) -watchlist <- list(train = dtrain, test = dtest) +evals <- list(train = dtrain, test = dtest) err <- function(label, pr) sum((pr > 0.5) != label) / length(label) @@ -28,79 +28,125 @@ param <- list(objective = "binary:logistic", eval_metric = "error", max_depth = 2, nthread = n_threads) -test_that("cb.print.evaluation works as expected", { - - bst_evaluation <- c('train-auc' = 0.9, 'test-auc' = 0.8) - bst_evaluation_err <- NULL - begin_iteration <- 1 - end_iteration <- 7 - - f0 <- cb.print.evaluation(period = 0) - f1 <- cb.print.evaluation(period = 1) - f5 <- cb.print.evaluation(period = 5) - - expect_false(is.null(attr(f1, 'call'))) - expect_equal(attr(f1, 'name'), 'cb.print.evaluation') - - iteration <- 1 - expect_silent(f0()) - expect_output(f1(), "\\[1\\]\ttrain-auc:0.900000\ttest-auc:0.800000") - expect_output(f5(), "\\[1\\]\ttrain-auc:0.900000\ttest-auc:0.800000") - expect_null(f1()) +test_that("xgb.cb.print.evaluation works as expected for xgb.train", { + logs1 <- capture.output({ + model <- xgb.train( + data = dtrain, + params = list( + objective = "binary:logistic", + eval_metric = "auc", + max_depth = 2, + nthread = n_threads + ), + nrounds = 10, + evals = list(train = dtrain, test = dtest), + callbacks = list(xgb.cb.print.evaluation(period = 1)) + ) + }) + expect_equal(length(logs1), 10) + expect_true(all(grepl("^\\[\\d{1,2}\\]\ttrain-auc:0\\.\\d+\ttest-auc:0\\.\\d+\\s*$", logs1))) + lapply(seq(1, 10), function(x) expect_true(grepl(paste0("^\\[", x), logs1[x]))) + + logs2 <- capture.output({ + model <- xgb.train( + data = dtrain, + params = list( + objective = "binary:logistic", + eval_metric = "auc", + max_depth = 2, + nthread = n_threads + ), + nrounds = 10, + evals = list(train = dtrain, test = dtest), + callbacks = list(xgb.cb.print.evaluation(period = 2)) + ) + }) + expect_equal(length(logs2), 6) + expect_true(all(grepl("^\\[\\d{1,2}\\]\ttrain-auc:0\\.\\d+\ttest-auc:0\\.\\d+\\s*$", logs2))) + seq_matches <- c(seq(1, 10, 2), 10) + lapply(seq_along(seq_matches), function(x) expect_true(grepl(paste0("^\\[", seq_matches[x]), logs2[x]))) +}) - iteration <- 2 - expect_output(f1(), "\\[2\\]\ttrain-auc:0.900000\ttest-auc:0.800000") - expect_silent(f5()) +test_that("xgb.cb.print.evaluation works as expected for xgb.cv", { + logs1 <- capture.output({ + model <- xgb.cv( + data = dtrain, + params = list( + objective = "binary:logistic", + eval_metric = "auc", + max_depth = 2, + nthread = n_threads + ), + nrounds = 10, + nfold = 3, + callbacks = list(xgb.cb.print.evaluation(period = 1, showsd = TRUE)) + ) + }) + expect_equal(length(logs1), 10) + expect_true(all(grepl("^\\[\\d{1,2}\\]\ttrain-auc:0\\.\\d+±0\\.\\d+\ttest-auc:0\\.\\d+±0\\.\\d+\\s*$", logs1))) + lapply(seq(1, 10), function(x) expect_true(grepl(paste0("^\\[", x), logs1[x]))) + + logs2 <- capture.output({ + model <- xgb.cv( + data = dtrain, + params = list( + objective = "binary:logistic", + eval_metric = "auc", + max_depth = 2, + nthread = n_threads + ), + nrounds = 10, + nfold = 3, + callbacks = list(xgb.cb.print.evaluation(period = 2, showsd = TRUE)) + ) + }) + expect_equal(length(logs2), 6) + expect_true(all(grepl("^\\[\\d{1,2}\\]\ttrain-auc:0\\.\\d+±0\\.\\d+\ttest-auc:0\\.\\d+±0\\.\\d+\\s*$", logs2))) + seq_matches <- c(seq(1, 10, 2), 10) + lapply(seq_along(seq_matches), function(x) expect_true(grepl(paste0("^\\[", seq_matches[x]), logs2[x]))) +}) - iteration <- 7 - expect_output(f1(), "\\[7\\]\ttrain-auc:0.900000\ttest-auc:0.800000") - expect_output(f5(), "\\[7\\]\ttrain-auc:0.900000\ttest-auc:0.800000") +test_that("xgb.cb.evaluation.log works as expected for xgb.train", { + model <- xgb.train( + data = dtrain, + params = list( + objective = "binary:logistic", + eval_metric = "auc", + max_depth = 2, + nthread = n_threads + ), + nrounds = 10, + verbose = FALSE, + evals = list(train = dtrain, test = dtest), + callbacks = list(xgb.cb.evaluation.log()) + ) + logs <- attributes(model)$evaluation_log - bst_evaluation_err <- c('train-auc' = 0.1, 'test-auc' = 0.2) - expect_output(f1(), "\\[7\\]\ttrain-auc:0.900000±0.100000\ttest-auc:0.800000±0.200000") + expect_equal(nrow(logs), 10) + expect_equal(colnames(logs), c("iter", "train_auc", "test_auc")) }) -test_that("cb.evaluation.log works as expected", { - - bst_evaluation <- c('train-auc' = 0.9, 'test-auc' = 0.8) - bst_evaluation_err <- NULL - - evaluation_log <- list() - f <- cb.evaluation.log() - - expect_false(is.null(attr(f, 'call'))) - expect_equal(attr(f, 'name'), 'cb.evaluation.log') - - iteration <- 1 - expect_silent(f()) - expect_equal(evaluation_log, - list(c(iter = 1, bst_evaluation))) - iteration <- 2 - expect_silent(f()) - expect_equal(evaluation_log, - list(c(iter = 1, bst_evaluation), c(iter = 2, bst_evaluation))) - expect_silent(f(finalize = TRUE)) - expect_equal(evaluation_log, - data.table::data.table(iter = 1:2, train_auc = c(0.9, 0.9), test_auc = c(0.8, 0.8))) - - bst_evaluation_err <- c('train-auc' = 0.1, 'test-auc' = 0.2) - evaluation_log <- list() - f <- cb.evaluation.log() - - iteration <- 1 - expect_silent(f()) - expect_equal(evaluation_log, - list(c(iter = 1, c(bst_evaluation, bst_evaluation_err)))) - iteration <- 2 - expect_silent(f()) - expect_equal(evaluation_log, - list(c(iter = 1, c(bst_evaluation, bst_evaluation_err)), - c(iter = 2, c(bst_evaluation, bst_evaluation_err)))) - expect_silent(f(finalize = TRUE)) - expect_equal(evaluation_log, - data.table::data.table(iter = 1:2, - train_auc_mean = c(0.9, 0.9), train_auc_std = c(0.1, 0.1), - test_auc_mean = c(0.8, 0.8), test_auc_std = c(0.2, 0.2))) +test_that("xgb.cb.evaluation.log works as expected for xgb.cv", { + model <- xgb.cv( + data = dtrain, + params = list( + objective = "binary:logistic", + eval_metric = "auc", + max_depth = 2, + nthread = n_threads + ), + nrounds = 10, + verbose = FALSE, + nfold = 3, + callbacks = list(xgb.cb.evaluation.log()) + ) + logs <- model$evaluation_log + + expect_equal(nrow(logs), 10) + expect_equal( + colnames(logs), + c("iter", "train_auc_mean", "train_auc_std", "test_auc_mean", "test_auc_std") + ) }) @@ -109,26 +155,26 @@ param <- list(objective = "binary:logistic", eval_metric = "error", test_that("can store evaluation_log without printing", { expect_silent( - bst <- xgb.train(param, dtrain, nrounds = 10, watchlist, eta = 1, verbose = 0) + bst <- xgb.train(param, dtrain, nrounds = 10, evals = evals, eta = 1, verbose = 0) ) expect_false(is.null(attributes(bst)$evaluation_log)) expect_false(is.null(attributes(bst)$evaluation_log$train_error)) expect_lt(attributes(bst)$evaluation_log[, min(train_error)], 0.2) }) -test_that("cb.reset.parameters works as expected", { +test_that("xgb.cb.reset.parameters works as expected", { # fixed eta set.seed(111) - bst0 <- xgb.train(param, dtrain, nrounds = 2, watchlist, eta = 0.9, verbose = 0) + bst0 <- xgb.train(param, dtrain, nrounds = 2, evals = evals, eta = 0.9, verbose = 0) expect_false(is.null(attributes(bst0)$evaluation_log)) expect_false(is.null(attributes(bst0)$evaluation_log$train_error)) # same eta but re-set as a vector parameter in the callback set.seed(111) my_par <- list(eta = c(0.9, 0.9)) - bst1 <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0, - callbacks = list(cb.reset.parameters(my_par))) + bst1 <- xgb.train(param, dtrain, nrounds = 2, evals = evals, verbose = 0, + callbacks = list(xgb.cb.reset.parameters(my_par))) expect_false(is.null(attributes(bst1)$evaluation_log$train_error)) expect_equal(attributes(bst0)$evaluation_log$train_error, attributes(bst1)$evaluation_log$train_error) @@ -136,8 +182,8 @@ test_that("cb.reset.parameters works as expected", { # same eta but re-set via a function in the callback set.seed(111) my_par <- list(eta = function(itr, itr_end) 0.9) - bst2 <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0, - callbacks = list(cb.reset.parameters(my_par))) + bst2 <- xgb.train(param, dtrain, nrounds = 2, evals = evals, verbose = 0, + callbacks = list(xgb.cb.reset.parameters(my_par))) expect_false(is.null(attributes(bst2)$evaluation_log$train_error)) expect_equal(attributes(bst0)$evaluation_log$train_error, attributes(bst2)$evaluation_log$train_error) @@ -145,39 +191,39 @@ test_that("cb.reset.parameters works as expected", { # different eta re-set as a vector parameter in the callback set.seed(111) my_par <- list(eta = c(0.6, 0.5)) - bst3 <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0, - callbacks = list(cb.reset.parameters(my_par))) + bst3 <- xgb.train(param, dtrain, nrounds = 2, evals = evals, verbose = 0, + callbacks = list(xgb.cb.reset.parameters(my_par))) expect_false(is.null(attributes(bst3)$evaluation_log$train_error)) expect_false(all(attributes(bst0)$evaluation_log$train_error == attributes(bst3)$evaluation_log$train_error)) # resetting multiple parameters at the same time runs with no error my_par <- list(eta = c(1., 0.5), gamma = c(1, 2), max_depth = c(4, 8)) expect_error( - bst4 <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0, - callbacks = list(cb.reset.parameters(my_par))) + bst4 <- xgb.train(param, dtrain, nrounds = 2, evals = evals, verbose = 0, + callbacks = list(xgb.cb.reset.parameters(my_par))) , NA) # NA = no error # CV works as well expect_error( bst4 <- xgb.cv(param, dtrain, nfold = 2, nrounds = 2, verbose = 0, - callbacks = list(cb.reset.parameters(my_par))) + callbacks = list(xgb.cb.reset.parameters(my_par))) , NA) # NA = no error # expect no learning with 0 learning rate my_par <- list(eta = c(0., 0.)) - bstX <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0, - callbacks = list(cb.reset.parameters(my_par))) + bstX <- xgb.train(param, dtrain, nrounds = 2, evals = evals, verbose = 0, + callbacks = list(xgb.cb.reset.parameters(my_par))) expect_false(is.null(attributes(bstX)$evaluation_log$train_error)) er <- unique(attributes(bstX)$evaluation_log$train_error) expect_length(er, 1) expect_gt(er, 0.4) }) -test_that("cb.save.model works as expected", { +test_that("xgb.cb.save.model works as expected", { files <- c('xgboost_01.json', 'xgboost_02.json', 'xgboost.json') files <- unname(sapply(files, function(f) file.path(tempdir(), f))) for (f in files) if (file.exists(f)) file.remove(f) - bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, eta = 1, verbose = 0, + bst <- xgb.train(param, dtrain, nrounds = 2, evals = evals, eta = 1, verbose = 0, save_period = 1, save_name = file.path(tempdir(), "xgboost_%02d.json")) expect_true(file.exists(files[1])) expect_true(file.exists(files[2])) @@ -193,7 +239,7 @@ test_that("cb.save.model works as expected", { expect_equal(xgb.save.raw(bst), xgb.save.raw(b2)) # save_period = 0 saves the last iteration's model - bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, eta = 1, verbose = 0, + bst <- xgb.train(param, dtrain, nrounds = 2, evals = evals, eta = 1, verbose = 0, save_period = 0, save_name = file.path(tempdir(), 'xgboost.json')) expect_true(file.exists(files[3])) b2 <- xgb.load(files[3]) @@ -206,7 +252,7 @@ test_that("cb.save.model works as expected", { test_that("early stopping xgb.train works", { set.seed(11) expect_output( - bst <- xgb.train(param, dtrain, nrounds = 20, watchlist, eta = 0.3, + bst <- xgb.train(param, dtrain, nrounds = 20, evals = evals, eta = 0.3, early_stopping_rounds = 3, maximize = FALSE) , "Stopping. Best iteration") expect_false(is.null(xgb.attr(bst, "best_iteration"))) @@ -220,7 +266,7 @@ test_that("early stopping xgb.train works", { set.seed(11) expect_silent( - bst0 <- xgb.train(param, dtrain, nrounds = 20, watchlist, eta = 0.3, + bst0 <- xgb.train(param, dtrain, nrounds = 20, evals = evals, eta = 0.3, early_stopping_rounds = 3, maximize = FALSE, verbose = 0) ) expect_equal(attributes(bst)$evaluation_log, attributes(bst0)$evaluation_log) @@ -236,10 +282,10 @@ test_that("early stopping xgb.train works", { test_that("early stopping using a specific metric works", { set.seed(11) expect_output( - bst <- xgb.train(param[-2], dtrain, nrounds = 20, watchlist, eta = 0.6, + bst <- xgb.train(param[-2], dtrain, nrounds = 20, evals = evals, eta = 0.6, eval_metric = "logloss", eval_metric = "auc", - callbacks = list(cb.early.stop(stopping_rounds = 3, maximize = FALSE, - metric_name = 'test_logloss'))) + callbacks = list(xgb.cb.early.stop(stopping_rounds = 3, maximize = FALSE, + metric_name = 'test_logloss'))) , "Stopping. Best iteration") expect_false(is.null(xgb.attr(bst, "best_iteration"))) expect_lt(xgb.attr(bst, "best_iteration"), 19) @@ -269,7 +315,7 @@ test_that("early stopping works with titanic", { nrounds = 100, early_stopping_rounds = 3, nthread = n_threads, - watchlist = list(train = xgb.DMatrix(dtx, label = dty)) + evals = list(train = xgb.DMatrix(dtx, label = dty)) ) expect_true(TRUE) # should not crash @@ -281,10 +327,10 @@ test_that("early stopping xgb.cv works", { cv <- xgb.cv(param, dtrain, nfold = 5, eta = 0.3, nrounds = 20, early_stopping_rounds = 3, maximize = FALSE) , "Stopping. Best iteration") - expect_false(is.null(cv$best_iteration)) - expect_lt(cv$best_iteration, 19) + expect_false(is.null(cv$early_stop$best_iteration)) + expect_lt(cv$early_stop$best_iteration, 19) # the best error is min error: - expect_true(cv$evaluation_log[, test_error_mean[cv$best_iteration] == min(test_error_mean)]) + expect_true(cv$evaluation_log[, test_error_mean[cv$early_stop$best_iteration] == min(test_error_mean)]) }) test_that("prediction in xgb.cv works", { @@ -292,19 +338,19 @@ test_that("prediction in xgb.cv works", { nrounds <- 4 cv <- xgb.cv(param, dtrain, nfold = 5, eta = 0.5, nrounds = nrounds, prediction = TRUE, verbose = 0) expect_false(is.null(cv$evaluation_log)) - expect_false(is.null(cv$pred)) - expect_length(cv$pred, nrow(train$data)) - err_pred <- mean(sapply(cv$folds, function(f) mean(err(ltrain[f], cv$pred[f])))) + expect_false(is.null(cv$cv_predict$pred)) + expect_length(cv$cv_predict$pred, nrow(train$data)) + err_pred <- mean(sapply(cv$folds, function(f) mean(err(ltrain[f], cv$cv_predict$pred[f])))) err_log <- cv$evaluation_log[nrounds, test_error_mean] expect_equal(err_pred, err_log, tolerance = 1e-6) # save CV models set.seed(11) cvx <- xgb.cv(param, dtrain, nfold = 5, eta = 0.5, nrounds = nrounds, prediction = TRUE, verbose = 0, - callbacks = list(cb.cv.predict(save_models = TRUE))) + callbacks = list(xgb.cb.cv.predict(save_models = TRUE))) expect_equal(cv$evaluation_log, cvx$evaluation_log) - expect_length(cvx$models, 5) - expect_true(all(sapply(cvx$models, class) == 'xgb.Booster')) + expect_length(cvx$cv_predict$models, 5) + expect_true(all(sapply(cvx$cv_predict$models, class) == 'xgb.Booster')) }) test_that("prediction in xgb.cv works for gblinear too", { @@ -312,8 +358,8 @@ test_that("prediction in xgb.cv works for gblinear too", { p <- list(booster = 'gblinear', objective = "reg:logistic", nthread = n_threads) cv <- xgb.cv(p, dtrain, nfold = 5, eta = 0.5, nrounds = 2, prediction = TRUE, verbose = 0) expect_false(is.null(cv$evaluation_log)) - expect_false(is.null(cv$pred)) - expect_length(cv$pred, nrow(train$data)) + expect_false(is.null(cv$cv_predict$pred)) + expect_length(cv$cv_predict$pred, nrow(train$data)) }) test_that("prediction in early-stopping xgb.cv works", { @@ -321,17 +367,17 @@ test_that("prediction in early-stopping xgb.cv works", { expect_output( cv <- xgb.cv(param, dtrain, nfold = 5, eta = 0.1, nrounds = 20, early_stopping_rounds = 5, maximize = FALSE, stratified = FALSE, - prediction = TRUE, base_score = 0.5) + prediction = TRUE, base_score = 0.5, verbose = TRUE) , "Stopping. Best iteration") - expect_false(is.null(cv$best_iteration)) - expect_lt(cv$best_iteration, 19) + expect_false(is.null(cv$early_stop$best_iteration)) + expect_lt(cv$early_stop$best_iteration, 19) expect_false(is.null(cv$evaluation_log)) - expect_false(is.null(cv$pred)) - expect_length(cv$pred, nrow(train$data)) + expect_false(is.null(cv$cv_predict$pred)) + expect_length(cv$cv_predict$pred, nrow(train$data)) - err_pred <- mean(sapply(cv$folds, function(f) mean(err(ltrain[f], cv$pred[f])))) - err_log <- cv$evaluation_log[cv$best_iteration, test_error_mean] + err_pred <- mean(sapply(cv$folds, function(f) mean(err(ltrain[f], cv$cv_predict$pred[f])))) + err_log <- cv$evaluation_log[cv$early_stop$best_iteration, test_error_mean] expect_equal(err_pred, err_log, tolerance = 1e-6) err_log_last <- cv$evaluation_log[cv$niter, test_error_mean] expect_gt(abs(err_pred - err_log_last), 1e-4) @@ -341,14 +387,55 @@ test_that("prediction in xgb.cv for softprob works", { lb <- as.numeric(iris$Species) - 1 set.seed(11) expect_warning( - cv <- xgb.cv(data = as.matrix(iris[, -5]), label = lb, nfold = 4, + cv <- xgb.cv(data = xgb.DMatrix(as.matrix(iris[, -5]), label = lb), nfold = 4, eta = 0.5, nrounds = 5, max_depth = 3, nthread = n_threads, subsample = 0.8, gamma = 2, verbose = 0, prediction = TRUE, objective = "multi:softprob", num_class = 3) , NA) - expect_false(is.null(cv$pred)) - expect_equal(dim(cv$pred), c(nrow(iris), 3)) - expect_lt(diff(range(rowSums(cv$pred))), 1e-6) + expect_false(is.null(cv$cv_predict$pred)) + expect_equal(dim(cv$cv_predict$pred), c(nrow(iris), 3)) + expect_lt(diff(range(rowSums(cv$cv_predict$pred))), 1e-6) +}) + +test_that("prediction in xgb.cv works for multi-quantile", { + data(mtcars) + y <- mtcars$mpg + x <- as.matrix(mtcars[, -1]) + dm <- xgb.DMatrix(x, label = y, nthread = 1) + cv <- xgb.cv( + data = dm, + params = list( + objective = "reg:quantileerror", + quantile_alpha = c(0.1, 0.2, 0.5, 0.8, 0.9), + nthread = 1 + ), + nrounds = 5, + nfold = 3, + prediction = TRUE, + verbose = 0 + ) + expect_equal(dim(cv$cv_predict$pred), c(nrow(x), 5)) +}) + +test_that("prediction in xgb.cv works for multi-output", { + data(mtcars) + y <- mtcars$mpg + x <- as.matrix(mtcars[, -1]) + dm <- xgb.DMatrix(x, label = cbind(y, -y), nthread = 1) + cv <- xgb.cv( + data = dm, + params = list( + tree_method = "hist", + multi_strategy = "multi_output_tree", + objective = "reg:squarederror", + nthread = n_threads + ), + nrounds = 5, + nfold = 3, + prediction = TRUE, + verbose = 0 + ) + expect_equal(dim(cv$cv_predict$pred), c(nrow(x), 2)) }) test_that("prediction in xgb.cv works for multi-quantile", { @@ -368,7 +455,7 @@ test_that("prediction in xgb.cv works for multi-quantile", { prediction = TRUE, verbose = 0 ) - expect_equal(dim(cv$pred), c(nrow(x), 5)) + expect_equal(dim(cv$cv_predict$pred), c(nrow(x), 5)) }) test_that("prediction in xgb.cv works for multi-output", { @@ -389,5 +476,5 @@ test_that("prediction in xgb.cv works for multi-output", { prediction = TRUE, verbose = 0 ) - expect_equal(dim(cv$pred), c(nrow(x), 2)) + expect_equal(dim(cv$cv_predict$pred), c(nrow(x), 2)) }) diff --git a/R-package/tests/testthat/test_custom_objective.R b/R-package/tests/testthat/test_custom_objective.R index c6503124682d..d3050b152aa0 100644 --- a/R-package/tests/testthat/test_custom_objective.R +++ b/R-package/tests/testthat/test_custom_objective.R @@ -12,7 +12,7 @@ dtrain <- xgb.DMatrix( dtest <- xgb.DMatrix( agaricus.test$data, label = agaricus.test$label, nthread = n_threads ) -watchlist <- list(eval = dtest, train = dtrain) +evals <- list(eval = dtest, train = dtrain) logregobj <- function(preds, dtrain) { labels <- getinfo(dtrain, "label") @@ -33,7 +33,7 @@ param <- list(max_depth = 2, eta = 1, nthread = n_threads, num_round <- 2 test_that("custom objective works", { - bst <- xgb.train(param, dtrain, num_round, watchlist) + bst <- xgb.train(param, dtrain, num_round, evals) expect_equal(class(bst), "xgb.Booster") expect_false(is.null(attributes(bst)$evaluation_log)) expect_false(is.null(attributes(bst)$evaluation_log$eval_error)) @@ -48,7 +48,7 @@ test_that("custom objective in CV works", { }) test_that("custom objective with early stop works", { - bst <- xgb.train(param, dtrain, 10, watchlist) + bst <- xgb.train(param, dtrain, 10, evals) expect_equal(class(bst), "xgb.Booster") train_log <- attributes(bst)$evaluation_log$train_error expect_true(all(diff(train_log) <= 0)) @@ -66,7 +66,7 @@ test_that("custom objective using DMatrix attr works", { return(list(grad = grad, hess = hess)) } param$objective <- logregobjattr - bst <- xgb.train(param, dtrain, num_round, watchlist) + bst <- xgb.train(param, dtrain, num_round, evals) expect_equal(class(bst), "xgb.Booster") }) diff --git a/R-package/tests/testthat/test_dmatrix.R b/R-package/tests/testthat/test_dmatrix.R index 0612406444ae..548afece378c 100644 --- a/R-package/tests/testthat/test_dmatrix.R +++ b/R-package/tests/testthat/test_dmatrix.R @@ -41,13 +41,13 @@ test_that("xgb.DMatrix: basic construction", { params <- list(tree_method = "hist", nthread = n_threads) bst_fd <- xgb.train( - params, nrounds = 8, fd, watchlist = list(train = fd) + params, nrounds = 8, fd, evals = list(train = fd) ) bst_dgr <- xgb.train( - params, nrounds = 8, fdgr, watchlist = list(train = fdgr) + params, nrounds = 8, fdgr, evals = list(train = fdgr) ) bst_dgc <- xgb.train( - params, nrounds = 8, fdgc, watchlist = list(train = fdgc) + params, nrounds = 8, fdgc, evals = list(train = fdgc) ) raw_fd <- xgb.save.raw(bst_fd, raw_format = "ubj") @@ -243,7 +243,7 @@ test_that("xgb.DMatrix: print", { txt <- capture.output({ print(dtrain) }) - expect_equal(txt, "xgb.DMatrix dim: 6513 x 126 info: label weight base_margin colnames: yes") + expect_equal(txt, "xgb.DMatrix dim: 6513 x 126 info: base_margin, label, weight colnames: yes") # DMatrix with just features dtrain <- xgb.DMatrix( @@ -724,6 +724,44 @@ test_that("xgb.DMatrix: quantile cuts look correct", { ) }) +test_that("xgb.DMatrix: slicing keeps field indicators", { + data(mtcars) + x <- as.matrix(mtcars[, -1]) + y <- mtcars[, 1] + dm <- xgb.DMatrix( + data = x, + label_lower_bound = -y, + label_upper_bound = y, + nthread = 1 + ) + idx_take <- seq(1, 5) + dm_slice <- xgb.slice.DMatrix(dm, idx_take) + + expect_true(xgb.DMatrix.hasinfo(dm_slice, "label_lower_bound")) + expect_true(xgb.DMatrix.hasinfo(dm_slice, "label_upper_bound")) + expect_false(xgb.DMatrix.hasinfo(dm_slice, "label")) + + expect_equal(getinfo(dm_slice, "label_lower_bound"), -y[idx_take], tolerance = 1e-6) + expect_equal(getinfo(dm_slice, "label_upper_bound"), y[idx_take], tolerance = 1e-6) +}) + +test_that("xgb.DMatrix: can slice with groups", { + data(iris) + x <- as.matrix(iris[, -5]) + set.seed(123) + y <- sample(3, size = nrow(x), replace = TRUE) + group <- c(50, 50, 50) + dm <- xgb.DMatrix(x, label = y, group = group, nthread = 1) + idx_take <- seq(1, 50) + dm_slice <- xgb.slice.DMatrix(dm, idx_take, allow_groups = TRUE) + + expect_true(xgb.DMatrix.hasinfo(dm_slice, "label")) + expect_false(xgb.DMatrix.hasinfo(dm_slice, "group")) + expect_false(xgb.DMatrix.hasinfo(dm_slice, "qid")) + expect_null(getinfo(dm_slice, "group")) + expect_equal(getinfo(dm_slice, "label"), y[idx_take], tolerance = 1e-6) +}) + test_that("xgb.DMatrix: can read CSV", { txt <- paste( "1,2,3", diff --git a/R-package/tests/testthat/test_feature_weights.R b/R-package/tests/testthat/test_feature_weights.R index 4ed78c9b6cfe..54fec67cfcf5 100644 --- a/R-package/tests/testthat/test_feature_weights.R +++ b/R-package/tests/testthat/test_feature_weights.R @@ -25,7 +25,7 @@ test_that("training with feature weights works", { expect_lt(importance[1, Frequency], importance[9, Frequency]) } - for (tm in c("hist", "approx", "exact")) { + for (tm in c("hist", "approx")) { test(tm) } }) diff --git a/R-package/tests/testthat/test_glm.R b/R-package/tests/testthat/test_glm.R index 349bcce8d1f5..b59de8b62f15 100644 --- a/R-package/tests/testthat/test_glm.R +++ b/R-package/tests/testthat/test_glm.R @@ -14,37 +14,37 @@ test_that("gblinear works", { param <- list(objective = "binary:logistic", eval_metric = "error", booster = "gblinear", nthread = n_threads, eta = 0.8, alpha = 0.0001, lambda = 0.0001) - watchlist <- list(eval = dtest, train = dtrain) + evals <- list(eval = dtest, train = dtrain) n <- 5 # iterations ERR_UL <- 0.005 # upper limit for the test set error VERB <- 0 # chatterbox switch param$updater <- 'shotgun' - bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'shuffle') + bst <- xgb.train(param, dtrain, n, evals, verbose = VERB, feature_selector = 'shuffle') ypred <- predict(bst, dtest) expect_equal(length(getinfo(dtest, 'label')), 1611) expect_lt(attributes(bst)$evaluation_log$eval_error[n], ERR_UL) - bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'cyclic', - callbacks = list(cb.gblinear.history())) + bst <- xgb.train(param, dtrain, n, evals, verbose = VERB, feature_selector = 'cyclic', + callbacks = list(xgb.cb.gblinear.history())) expect_lt(attributes(bst)$evaluation_log$eval_error[n], ERR_UL) h <- xgb.gblinear.history(bst) expect_equal(dim(h), c(n, ncol(dtrain) + 1)) expect_is(h, "matrix") param$updater <- 'coord_descent' - bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'cyclic') + bst <- xgb.train(param, dtrain, n, evals, verbose = VERB, feature_selector = 'cyclic') expect_lt(attributes(bst)$evaluation_log$eval_error[n], ERR_UL) - bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'shuffle') + bst <- xgb.train(param, dtrain, n, evals, verbose = VERB, feature_selector = 'shuffle') expect_lt(attributes(bst)$evaluation_log$eval_error[n], ERR_UL) - bst <- xgb.train(param, dtrain, 2, watchlist, verbose = VERB, feature_selector = 'greedy') + bst <- xgb.train(param, dtrain, 2, evals, verbose = VERB, feature_selector = 'greedy') expect_lt(attributes(bst)$evaluation_log$eval_error[2], ERR_UL) - bst <- xgb.train(param, dtrain, n, watchlist, verbose = VERB, feature_selector = 'thrifty', - top_k = 50, callbacks = list(cb.gblinear.history(sparse = TRUE))) + bst <- xgb.train(param, dtrain, n, evals, verbose = VERB, feature_selector = 'thrifty', + top_k = 50, callbacks = list(xgb.cb.gblinear.history(sparse = TRUE))) expect_lt(attributes(bst)$evaluation_log$eval_error[n], ERR_UL) h <- xgb.gblinear.history(bst) expect_equal(dim(h), c(n, ncol(dtrain) + 1)) diff --git a/R-package/tests/testthat/test_ranking.R b/R-package/tests/testthat/test_ranking.R index e49a32025e0f..0e7db42da0b2 100644 --- a/R-package/tests/testthat/test_ranking.R +++ b/R-package/tests/testthat/test_ranking.R @@ -15,7 +15,7 @@ test_that('Test ranking with unweighted data', { params <- list(eta = 1, tree_method = 'exact', objective = 'rank:pairwise', max_depth = 1, eval_metric = 'auc', eval_metric = 'aucpr', nthread = n_threads) - bst <- xgb.train(params, dtrain, nrounds = 10, watchlist = list(train = dtrain)) + bst <- xgb.train(params, dtrain, nrounds = 10, evals = list(train = dtrain)) # Check if the metric is monotone increasing expect_true(all(diff(attributes(bst)$evaluation_log$train_auc) >= 0)) expect_true(all(diff(attributes(bst)$evaluation_log$train_aucpr) >= 0)) @@ -39,7 +39,7 @@ test_that('Test ranking with weighted data', { eta = 1, tree_method = "exact", objective = "rank:pairwise", max_depth = 1, eval_metric = "auc", eval_metric = "aucpr", nthread = n_threads ) - bst <- xgb.train(params, dtrain, nrounds = 10, watchlist = list(train = dtrain)) + bst <- xgb.train(params, dtrain, nrounds = 10, evals = list(train = dtrain)) # Check if the metric is monotone increasing expect_true(all(diff(attributes(bst)$evaluation_log$train_auc) >= 0)) expect_true(all(diff(attributes(bst)$evaluation_log$train_aucpr) >= 0)) diff --git a/R-package/tests/testthat/test_update.R b/R-package/tests/testthat/test_update.R index 3c88178e08d3..7fdc6eb84bb3 100644 --- a/R-package/tests/testthat/test_update.R +++ b/R-package/tests/testthat/test_update.R @@ -17,7 +17,7 @@ dtest <- xgb.DMatrix( win32_flag <- .Platform$OS.type == "windows" && .Machine$sizeof.pointer != 8 test_that("updating the model works", { - watchlist <- list(train = dtrain, test = dtest) + evals <- list(train = dtrain, test = dtest) # no-subsampling p1 <- list( @@ -25,19 +25,19 @@ test_that("updating the model works", { updater = "grow_colmaker,prune" ) set.seed(11) - bst1 <- xgb.train(p1, dtrain, nrounds = 10, watchlist, verbose = 0) + bst1 <- xgb.train(p1, dtrain, nrounds = 10, evals = evals, verbose = 0) tr1 <- xgb.model.dt.tree(model = bst1) # with subsampling p2 <- modifyList(p1, list(subsample = 0.1)) set.seed(11) - bst2 <- xgb.train(p2, dtrain, nrounds = 10, watchlist, verbose = 0) + bst2 <- xgb.train(p2, dtrain, nrounds = 10, evals = evals, verbose = 0) tr2 <- xgb.model.dt.tree(model = bst2) # the same no-subsampling boosting with an extra 'refresh' updater: p1r <- modifyList(p1, list(updater = 'grow_colmaker,prune,refresh', refresh_leaf = FALSE)) set.seed(11) - bst1r <- xgb.train(p1r, dtrain, nrounds = 10, watchlist, verbose = 0) + bst1r <- xgb.train(p1r, dtrain, nrounds = 10, evals = evals, verbose = 0) tr1r <- xgb.model.dt.tree(model = bst1r) # all should be the same when no subsampling expect_equal(attributes(bst1)$evaluation_log, attributes(bst1r)$evaluation_log) @@ -53,7 +53,7 @@ test_that("updating the model works", { # the same boosting with subsampling with an extra 'refresh' updater: p2r <- modifyList(p2, list(updater = 'grow_colmaker,prune,refresh', refresh_leaf = FALSE)) set.seed(11) - bst2r <- xgb.train(p2r, dtrain, nrounds = 10, watchlist, verbose = 0) + bst2r <- xgb.train(p2r, dtrain, nrounds = 10, evals = evals, verbose = 0) tr2r <- xgb.model.dt.tree(model = bst2r) # should be the same evaluation but different gains and larger cover expect_equal(attributes(bst2)$evaluation_log, attributes(bst2r)$evaluation_log) @@ -66,7 +66,7 @@ test_that("updating the model works", { # process type 'update' for no-subsampling model, refreshing the tree stats AND leaves from training data: set.seed(123) p1u <- modifyList(p1, list(process_type = 'update', updater = 'refresh', refresh_leaf = TRUE)) - bst1u <- xgb.train(p1u, dtrain, nrounds = 10, watchlist, verbose = 0, xgb_model = bst1) + bst1u <- xgb.train(p1u, dtrain, nrounds = 10, evals = evals, verbose = 0, xgb_model = bst1) tr1u <- xgb.model.dt.tree(model = bst1u) # all should be the same when no subsampling expect_equal(attributes(bst1)$evaluation_log, attributes(bst1u)$evaluation_log) @@ -79,7 +79,7 @@ test_that("updating the model works", { # same thing but with a serialized model set.seed(123) - bst1u <- xgb.train(p1u, dtrain, nrounds = 10, watchlist, verbose = 0, xgb_model = xgb.save.raw(bst1)) + bst1u <- xgb.train(p1u, dtrain, nrounds = 10, evals = evals, verbose = 0, xgb_model = xgb.save.raw(bst1)) tr1u <- xgb.model.dt.tree(model = bst1u) # all should be the same when no subsampling expect_equal(attributes(bst1)$evaluation_log, attributes(bst1u)$evaluation_log) @@ -87,7 +87,7 @@ test_that("updating the model works", { # process type 'update' for model with subsampling, refreshing only the tree stats from training data: p2u <- modifyList(p2, list(process_type = 'update', updater = 'refresh', refresh_leaf = FALSE)) - bst2u <- xgb.train(p2u, dtrain, nrounds = 10, watchlist, verbose = 0, xgb_model = bst2) + bst2u <- xgb.train(p2u, dtrain, nrounds = 10, evals = evals, verbose = 0, xgb_model = bst2) tr2u <- xgb.model.dt.tree(model = bst2u) # should be the same evaluation but different gains and larger cover expect_equal(attributes(bst2)$evaluation_log, attributes(bst2u)$evaluation_log) @@ -102,7 +102,7 @@ test_that("updating the model works", { # process type 'update' for no-subsampling model, refreshing only the tree stats from TEST data: p1ut <- modifyList(p1, list(process_type = 'update', updater = 'refresh', refresh_leaf = FALSE)) - bst1ut <- xgb.train(p1ut, dtest, nrounds = 10, watchlist, verbose = 0, xgb_model = bst1) + bst1ut <- xgb.train(p1ut, dtest, nrounds = 10, evals = evals, verbose = 0, xgb_model = bst1) tr1ut <- xgb.model.dt.tree(model = bst1ut) # should be the same evaluations but different gains and smaller cover (test data is smaller) expect_equal(attributes(bst1)$evaluation_log, attributes(bst1ut)$evaluation_log) @@ -115,18 +115,18 @@ test_that("updating works for multiclass & multitree", { dtr <- xgb.DMatrix( as.matrix(iris[, -5]), label = as.numeric(iris$Species) - 1, nthread = n_threads ) - watchlist <- list(train = dtr) + evals <- list(train = dtr) p0 <- list(max_depth = 2, eta = 0.5, nthread = n_threads, subsample = 0.6, objective = "multi:softprob", num_class = 3, num_parallel_tree = 2, base_score = 0) set.seed(121) - bst0 <- xgb.train(p0, dtr, 5, watchlist, verbose = 0) + bst0 <- xgb.train(p0, dtr, 5, evals = evals, verbose = 0) tr0 <- xgb.model.dt.tree(model = bst0) # run update process for an original model with subsampling p0u <- modifyList(p0, list(process_type = 'update', updater = 'refresh', refresh_leaf = FALSE)) bst0u <- xgb.train(p0u, dtr, nrounds = xgb.get.num.boosted.rounds(bst0), - watchlist, xgb_model = bst0, verbose = 0) + evals = evals, xgb_model = bst0, verbose = 0) tr0u <- xgb.model.dt.tree(model = bst0u) # should be the same evaluation but different gains and larger cover diff --git a/R-package/vignettes/xgboostPresentation.Rmd b/R-package/vignettes/xgboostPresentation.Rmd index 0a6432d5f9cf..fc49adc0fcee 100644 --- a/R-package/vignettes/xgboostPresentation.Rmd +++ b/R-package/vignettes/xgboostPresentation.Rmd @@ -341,10 +341,10 @@ One way to measure progress in learning of a model is to provide to **XGBoost** > in some way it is similar to what we have done above with the average error. The main difference is that below it was after building the model, and now it is during the construction that we measure errors. -For the purpose of this example, we use `watchlist` parameter. It is a list of `xgb.DMatrix`, each of them tagged with a name. +For the purpose of this example, we use the `evals` parameter. It is a list of `xgb.DMatrix` objects, each of them tagged with a name. -```{r watchlist, message=F, warning=F} -watchlist <- list(train = dtrain, test = dtest) +```{r evals, message=F, warning=F} +evals <- list(train = dtrain, test = dtest) bst <- xgb.train( data = dtrain @@ -355,7 +355,7 @@ bst <- xgb.train( , objective = "binary:logistic" ) , nrounds = 2 - , watchlist = watchlist + , evals = evals ) ``` @@ -367,7 +367,7 @@ If with your own dataset you have not such results, you should think about how y For a better understanding of the learning progression, you may want to have some specific metric or even use multiple evaluation metrics. -```{r watchlist2, message=F, warning=F} +```{r evals2, message=F, warning=F} bst <- xgb.train( data = dtrain , max_depth = 2 @@ -379,7 +379,7 @@ bst <- xgb.train( , eval_metric = "logloss" ) , nrounds = 2 - , watchlist = watchlist + , evals = evals ) ``` @@ -401,7 +401,7 @@ bst <- xgb.train( , eval_metric = "logloss" ) , nrounds = 2 - , watchlist = watchlist + , evals = evals ) ``` @@ -430,7 +430,7 @@ bst <- xgb.train( , objective = "binary:logistic" ) , nrounds = 2 - , watchlist = watchlist + , evals = evals ) ``` diff --git a/README.md b/README.md index 063b291259d8..234bd7dba76e 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ - eXtreme Gradient Boosting + eXtreme Gradient Boosting =========== + [![Build Status](https://badge.buildkite.com/aca47f40a32735c00a8550540c5eeff6a4c1d246a580cae9b0.svg?branch=master)](https://buildkite.com/xgboost/xgboost-ci) [![XGBoost-CI](https://github.com/dmlc/xgboost/workflows/XGBoost-CI/badge.svg?branch=master)](https://github.com/dmlc/xgboost/actions) [![Documentation Status](https://readthedocs.org/projects/xgboost/badge/?version=latest)](https://xgboost.readthedocs.org) diff --git a/cmake/Utils.cmake b/cmake/Utils.cmake index 9c373bb019ec..317a71c00d22 100644 --- a/cmake/Utils.cmake +++ b/cmake/Utils.cmake @@ -151,6 +151,7 @@ function(xgboost_set_cuda_flags target) target_include_directories( ${target} PRIVATE ${xgboost_SOURCE_DIR}/gputreeshap + ${xgboost_SOURCE_DIR}/rabit/include ${CUDAToolkit_INCLUDE_DIRS}) if(MSVC) @@ -289,7 +290,7 @@ macro(xgboost_target_link_libraries target) endif() if(USE_NVTX) - target_link_libraries(${target} PRIVATE CUDA::nvToolsExt) + target_link_libraries(${target} PRIVATE CUDA::nvtx3) endif() if(MINGW) diff --git a/demo/dask/cpu_survival.py b/demo/dask/cpu_survival.py index 8bf464ce21d3..44032bab207f 100644 --- a/demo/dask/cpu_survival.py +++ b/demo/dask/cpu_survival.py @@ -6,6 +6,7 @@ import os +import dask.array as da import dask.dataframe as dd from dask.distributed import Client, LocalCluster @@ -13,7 +14,7 @@ from xgboost.dask import DaskDMatrix -def main(client): +def main(client: Client) -> da.Array: # Load an example survival data from CSV into a Dask data frame. # The Veterans' Administration Lung Cancer Trial # The Statistical Analysis of Failure Time Data by Kalbfleisch J. and Prentice R (1980) diff --git a/demo/dask/cpu_training.py b/demo/dask/cpu_training.py index 2bee444f7a89..b3a389458987 100644 --- a/demo/dask/cpu_training.py +++ b/demo/dask/cpu_training.py @@ -11,12 +11,12 @@ from xgboost.dask import DaskDMatrix -def main(client): +def main(client: Client) -> None: # generate some random data for demonstration m = 100000 n = 100 rng = da.random.default_rng(1) - X = rng.normal(size=(m, n)) + X = rng.normal(size=(m, n), chunks=(10000, -1)) y = X.sum(axis=1) # DaskDMatrix acts like normal DMatrix, works as a proxy for local @@ -40,7 +40,7 @@ def main(client): # you can pass output directly into `predict` too. prediction = dxgb.predict(client, bst, dtrain) print("Evaluation history:", history) - return prediction + print("Error:", da.sqrt((prediction - y) ** 2).mean().compute()) if __name__ == "__main__": diff --git a/demo/dask/dask_callbacks.py b/demo/dask/dask_callbacks.py index 4a7ec0f191cb..1a15b918a534 100644 --- a/demo/dask/dask_callbacks.py +++ b/demo/dask/dask_callbacks.py @@ -3,6 +3,8 @@ ==================================== """ +from typing import Any + import numpy as np from dask.distributed import Client, LocalCluster from dask_ml.datasets import make_regression @@ -13,7 +15,7 @@ from xgboost.dask import DaskDMatrix -def probability_for_going_backward(epoch): +def probability_for_going_backward(epoch: int) -> float: return 0.999 / (1.0 + 0.05 * np.log(1.0 + epoch)) @@ -23,7 +25,9 @@ class CustomEarlyStopping(xgb.callback.TrainingCallback): In the beginning, allow the metric to become worse with a probability of 0.999. As boosting progresses, the probability should be adjusted downward""" - def __init__(self, *, validation_set, target_metric, maximize, seed): + def __init__( + self, *, validation_set: str, target_metric: str, maximize: bool, seed: int + ) -> None: self.validation_set = validation_set self.target_metric = target_metric self.maximize = maximize @@ -34,7 +38,9 @@ def __init__(self, *, validation_set, target_metric, maximize, seed): else: self.better = lambda x, y: x < y - def after_iteration(self, model, epoch, evals_log): + def after_iteration( + self, model: Any, epoch: int, evals_log: xgb.callback.TrainingCallback.EvalsLog + ) -> bool: metric_history = evals_log[self.validation_set][self.target_metric] if len(metric_history) < 2 or self.better( metric_history[-1], metric_history[-2] @@ -42,7 +48,7 @@ def after_iteration(self, model, epoch, evals_log): return False # continue training p = probability_for_going_backward(epoch) go_backward = self.rng.choice(2, size=(1,), replace=True, p=[1 - p, p]).astype( - np.bool + np.bool_ )[0] print( "The validation metric went into the wrong direction. " @@ -54,7 +60,7 @@ def after_iteration(self, model, epoch, evals_log): return True # stop training -def main(client): +def main(client: Client) -> None: m = 100000 n = 100 X, y = make_regression(n_samples=m, n_features=n, chunks=200, random_state=0) diff --git a/demo/dask/sklearn_cpu_training.py b/demo/dask/sklearn_cpu_training.py index e91babb8407b..38a53c6ca71c 100644 --- a/demo/dask/sklearn_cpu_training.py +++ b/demo/dask/sklearn_cpu_training.py @@ -9,7 +9,7 @@ from xgboost import dask as dxgb -def main(client): +def main(client: Client) -> dxgb.Booster: # generate some random data for demonstration n = 100 m = 10000 diff --git a/demo/dask/sklearn_gpu_training.py b/demo/dask/sklearn_gpu_training.py index 7686909951e6..6161bf9a3402 100644 --- a/demo/dask/sklearn_gpu_training.py +++ b/demo/dask/sklearn_gpu_training.py @@ -12,7 +12,7 @@ from xgboost import dask as dxgb -def main(client): +def main(client: Client) -> dxgb.Booster: # generate some random data for demonstration n = 100 m = 1000000 diff --git a/demo/guide-python/external_memory.py b/demo/guide-python/external_memory.py index e4d1895d1a1a..b19f550c9149 100644 --- a/demo/guide-python/external_memory.py +++ b/demo/guide-python/external_memory.py @@ -84,7 +84,7 @@ def main(tmpdir: str) -> xgboost.Booster: it = Iterator(files) # For non-data arguments, specify it here once instead of passing them by the `next` # method. - missing = np.NaN + missing = np.nan Xy = xgboost.DMatrix(it, missing=missing, enable_categorical=False) # ``approx`` is also supported, but less efficient due to sketching. GPU behaves diff --git a/doc/conf.py b/doc/conf.py index 68ec39181ba0..ec58c5a5d456 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -250,7 +250,7 @@ def is_readthedocs_build(): html_theme_options = {"logo_only": True} -html_logo = "https://raw.githubusercontent.com/dmlc/dmlc.github.io/master/img/logo-m/xgboost.png" +html_logo = "https://xgboost.ai/images/logo/xgboost-logo-ng.png" html_css_files = ["css/custom.css"] diff --git a/doc/contrib/unit_tests.rst b/doc/contrib/unit_tests.rst index 662a632e27db..908e5ed99fa9 100644 --- a/doc/contrib/unit_tests.rst +++ b/doc/contrib/unit_tests.rst @@ -144,6 +144,14 @@ which provides higher flexibility. For example: ctest --verbose +If you need to debug errors on Windows using the debugger from VS, you can append the gtest flags in `test_main.cc`: + +.. code-block:: + + ::testing::GTEST_FLAG(filter) = "Suite.Test"; + ::testing::GTEST_FLAG(repeat) = 10; + + *********************************************** Sanitizers: Detect memory errors and data races *********************************************** diff --git a/doc/index.rst b/doc/index.rst index a2ae9bbd39da..7b241c0a17d2 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -28,7 +28,7 @@ Contents Python Package R Package JVM Package - Ruby Package + Ruby Package Swift Package Julia Package C Package diff --git a/doc/parameter.rst b/doc/parameter.rst index 7898bb363549..00f0eaea6193 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -118,7 +118,7 @@ Parameters for Tree Booster - All ``colsample_by*`` parameters have a range of (0, 1], the default value of 1, and specify the fraction of columns to be subsampled. - ``colsample_bytree`` is the subsample ratio of columns when constructing each tree. Subsampling occurs once for every tree constructed. - ``colsample_bylevel`` is the subsample ratio of columns for each level. Subsampling occurs once for every new depth level reached in a tree. Columns are subsampled from the set of columns chosen for the current tree. - - ``colsample_bynode`` is the subsample ratio of columns for each node (split). Subsampling occurs once every time a new split is evaluated. Columns are subsampled from the set of columns chosen for the current level. + - ``colsample_bynode`` is the subsample ratio of columns for each node (split). Subsampling occurs once every time a new split is evaluated. Columns are subsampled from the set of columns chosen for the current level. This is not supported by the exact tree method. - ``colsample_by*`` parameters work cumulatively. For instance, the combination ``{'colsample_bytree':0.5, 'colsample_bylevel':0.5, 'colsample_bynode':0.5}`` with 64 features will leave 8 features to choose from at @@ -489,7 +489,7 @@ Parameters for learning to rank (``rank:ndcg``, ``rank:map``, ``rank:pairwise``) These are parameters specific to learning to rank task. See :doc:`Learning to Rank ` for an in-depth explanation. -* ``lambdarank_pair_method`` [default = ``mean``] +* ``lambdarank_pair_method`` [default = ``topk``] How to construct pairs for pair-wise learning. @@ -500,7 +500,13 @@ These are parameters specific to learning to rank task. See :doc:`Learning to Ra It specifies the number of pairs sampled for each document when pair method is ``mean``, or the truncation level for queries when the pair method is ``topk``. For example, to train with ``ndcg@6``, set ``lambdarank_num_pair_per_sample`` to :math:`6` and ``lambdarank_pair_method`` to ``topk``. -* ``lambdarank_unbiased`` [default = ``false``] +* ``lambdarank_normalization`` [default = ``true``] + + .. versionadded:: 2.1.0 + + Whether to normalize the leaf value by lambda gradient. This can sometimes stagnate the training progress. + +* ``lambdarank_unbiased`` [default = ``false``] Specify whether do we need to debias input click data. diff --git a/doc/tutorials/dask.rst b/doc/tutorials/dask.rst index 4b145f9a95b2..3544f88b5731 100644 --- a/doc/tutorials/dask.rst +++ b/doc/tutorials/dask.rst @@ -237,41 +237,44 @@ For most of the use cases with GPUs, the `Dask-CUDA `_, for example, for GPUs and you can use Dask Cloud Provider to `deploy Dask clusters in the cloud `_. See the `Dask documentation for a more comprehensive list `_. +Using Dask's ``LocalCluster`` is convenient for getting started quickly on a local machine. Once you're ready to scale your work, though, there are a number of ways to deploy Dask on a distributed cluster. You can use `Dask-CUDA `_, for example, for GPUs and you can use Dask Cloud Provider to `deploy Dask clusters in the cloud `_. See the `Dask documentation for a more comprehensive list `_. In the example below, a ``KubeCluster`` is used for `deploying Dask on Kubernetes `_: .. code-block:: python - from dask_kubernetes import KubeCluster # Need to install the ``dask-kubernetes`` package + from dask_kubernetes.operator import KubeCluster # Need to install the ``dask-kubernetes`` package + from dask_kubernetes.operator.kubecluster.kubecluster import CreateMode + from dask.distributed import Client from xgboost import dask as dxgb - import dask import dask.array as da - dask.config.set({"kubernetes.scheduler-service-type": "LoadBalancer", - "kubernetes.scheduler-service-wait-timeout": 360, - "distributed.comm.timeouts.connect": 360}) - def main(): - '''Connect to a remote kube cluster with GPU nodes and run training on it.''' + '''Connect to a remote kube cluster with GPU nodes and run training on it.''' m = 1000 n = 10 kWorkers = 2 # assuming you have 2 GPU nodes on that cluster. # You need to work out the worker-spec yourself. See document in dask_kubernetes for # its usage. Here we just want to show that XGBoost works on various clusters. - cluster = KubeCluster.from_yaml('worker-spec.yaml', deploy_mode='remote') - cluster.scale(kWorkers) # scale to use all GPUs - with Client(cluster) as client: - X = da.random.random(size=(m, n), chunks=100) - y = da.random.random(size=(m, ), chunks=100) + # See notes below for why we use pre-allocated cluster. + with KubeCluster( + name="xgboost-test", + image="my-image-name:latest", + n_workers=kWorkers, + create_mode=CreateMode.CONNECT_ONLY, + shutdown_on_close=False, + ) as cluster: + with Client(cluster) as client: + X = da.random.random(size=(m, n), chunks=100) + y = X.sum(axis=1) - regressor = dxgb.DaskXGBRegressor(n_estimators=10, missing=0.0) - regressor.client = client - regressor.set_params(tree_method='hist', device="cuda") - regressor.fit(X, y, eval_set=[(X, y)]) + regressor = dxgb.DaskXGBRegressor(n_estimators=10, missing=0.0) + regressor.client = client + regressor.set_params(tree_method='hist', device="cuda") + regressor.fit(X, y, eval_set=[(X, y)]) if __name__ == '__main__': @@ -279,11 +282,46 @@ In the example below, a ``KubeCluster`` is used for `deploying Dask on Kubernete # main function will connect to that cluster and start training xgboost model. main() + Different cluster classes might have subtle differences like network configuration, or specific cluster implementation might contains bugs that we are not aware of. Open an issue if such case is found and there's no documentation on how to resolve it in that cluster implementation. +An interesting aspect of the Kubernetes cluster is that the pods may become available +after the Dask workflow has begun, which can cause issues with distributed XGBoost since +XGBoost expects the nodes used by input data to remain unchanged during training. To use +Kubernetes clusters, it is necessary to wait for all the pods to be online before +submitting XGBoost tasks. One can either create a wait function in Python or simply +pre-allocate a cluster with k8s tools (like ``kubectl``) before running dask workflows. To +pre-allocate a cluster, we can first generate the cluster spec using dask kubernetes: + +.. code-block:: python + + import json + + from dask_kubernetes.operator import make_cluster_spec + + spec = make_cluster_spec(name="xgboost-test", image="my-image-name:latest", n_workers=16) + with open("cluster-spec.json", "w") as fd: + json.dump(spec, fd, indent=2) + +.. code-block:: sh + + kubectl apply -f ./cluster-spec.json + + +Check whether the pods are available: + +.. code-block:: sh + + kubectl get pods + +Once all pods have been initialized, the Dask XGBoost workflow can be run, as in the +previous example. It is important to ensure that the cluster sets the parameter +``create_mode=CreateMode.CONNECT_ONLY`` and optionally ``shutdown_on_close=False`` if you +do not want to shut down the cluster after a single job. + ******* Threads ******* diff --git a/doc/tutorials/learning_to_rank.rst b/doc/tutorials/learning_to_rank.rst index 015f736e08eb..15a611bd0c32 100644 --- a/doc/tutorials/learning_to_rank.rst +++ b/doc/tutorials/learning_to_rank.rst @@ -48,11 +48,11 @@ Notice that the samples are sorted based on their query index in a non-decreasin import xgboost as xgb # Make a synthetic ranking dataset for demonstration - seed = 1994 + seed = 1994 X, y = make_classification(random_state=seed) rng = np.random.default_rng(seed) n_query_groups = 3 - qid = rng.integers(0, 3, size=X.shape[0]) + qid = rng.integers(0, n_query_groups, size=X.shape[0]) # Sort the inputs based on query index sorted_idx = np.argsort(qid) @@ -65,14 +65,14 @@ The simplest way to train a ranking model is by using the scikit-learn estimator .. code-block:: python ranker = xgb.XGBRanker(tree_method="hist", lambdarank_num_pair_per_sample=8, objective="rank:ndcg", lambdarank_pair_method="topk") - ranker.fit(X, y, qid=qid) + ranker.fit(X, y, qid=qid[sorted_idx]) Please note that, as of writing, there's no learning-to-rank interface in scikit-learn. As a result, the :py:class:`xgboost.XGBRanker` class does not fully conform the scikit-learn estimator guideline and can not be directly used with some of its utility functions. For instances, the ``auc_score`` and ``ndcg_score`` in scikit-learn don't consider query group information nor the pairwise loss. Most of the metrics are implemented as part of XGBoost, but to use scikit-learn utilities like :py:func:`sklearn.model_selection.cross_validation`, we need to make some adjustments in order to pass the ``qid`` as an additional parameter for :py:meth:`xgboost.XGBRanker.score`. Given a data frame ``X`` (either pandas or cuDF), add the column ``qid`` as follows: .. code-block:: python df = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])]) - df["qid"] = qid + df["qid"] = qid[sorted_idx] ranker.fit(df, y) # No need to pass qid as a separate argument from sklearn.model_selection import StratifiedGroupKFold, cross_val_score @@ -146,7 +146,8 @@ The consideration of effective pairs also applies to the choice of pair method ( When using the mean strategy for generating pairs, where the target metric (like ``NDCG``) is computed over the whole query list, users can specify how many pairs should be generated per each document, by setting the ``lambdarank_num_pair_per_sample``. XGBoost will randomly sample ``lambdarank_num_pair_per_sample`` pairs for each element in the query group (:math:`|pairs| = |query| \times num\_pairsample`). Often, setting it to 1 can produce reasonable results. In cases where performance is inadequate due to insufficient number of effective pairs being generated, set ``lambdarank_num_pair_per_sample`` to a higher value. As more document pairs are generated, more effective pairs will be generated as well. -On the other hand, if you are prioritizing the top :math:`k` documents, the ``lambdarank_num_pair_per_sample`` should be set slightly higher than :math:`k` (with a few more documents) to obtain a good training result. +On the other hand, if you are prioritizing the top :math:`k` documents, the ``lambdarank_num_pair_per_sample`` should be set slightly higher than :math:`k` (with a few more documents) to obtain a good training result. Lastly, XGBoost employs additional regularization for learning to rank objectives, which can be disabled by setting the ``lambdarank_normalization`` to ``False``. + **Summary** If you have large amount of training data: diff --git a/include/xgboost/base.h b/include/xgboost/base.h index 1f94c9b2fd1d..9abe72b87859 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -1,20 +1,18 @@ /** - * Copyright 2015-2023 by XGBoost Contributors + * Copyright 2015-2024, XGBoost Contributors * \file base.h * \brief Defines configuration macros and basic types for xgboost. */ #ifndef XGBOOST_BASE_H_ #define XGBOOST_BASE_H_ -#include -#include +#include // for omp_uint, omp_ulong -#include -#include -#include -#include -#include -#include +#include // for int32_t, uint64_t, int16_t +#include // for ostream +#include // for string +#include // for pair +#include // for vector /*! * \brief string flag for R library, to leave hooks when needed. @@ -86,34 +84,31 @@ #endif // !defined(XGBOOST_MM_PREFETCH_PRESENT) && !defined() -/*! \brief namespace of xgboost*/ namespace xgboost { - /*! \brief unsigned integer type used for feature index. */ -using bst_uint = uint32_t; // NOLINT +using bst_uint = std::uint32_t; // NOLINT /*! \brief unsigned long integers */ -using bst_ulong = uint64_t; // NOLINT +using bst_ulong = std::uint64_t; // NOLINT /*! \brief float type, used for storing statistics */ using bst_float = float; // NOLINT /*! \brief Categorical value type. */ -using bst_cat_t = int32_t; // NOLINT +using bst_cat_t = std::int32_t; // NOLINT /*! \brief Type for data column (feature) index. */ -using bst_feature_t = uint32_t; // NOLINT -/*! \brief Type for histogram bin index. */ -using bst_bin_t = int32_t; // NOLINT -/*! \brief Type for data row index. - * - * Be careful `std::size_t' is implementation-defined. Meaning that the binary - * representation of DMatrix might not be portable across platform. Booster model should - * be portable as parameters are floating points. +using bst_feature_t = std::uint32_t; // NOLINT +/** + * @brief Type for histogram bin index. We sometimes use -1 to indicate invalid bin. */ -using bst_row_t = std::size_t; // NOLINT +using bst_bin_t = std::int32_t; // NOLINT +/** + * @brief Type for data row index (sample). + */ +using bst_idx_t = std::uint64_t; // NOLINT /*! \brief Type for tree node index. */ using bst_node_t = std::int32_t; // NOLINT /*! \brief Type for ranking group index. */ using bst_group_t = std::uint32_t; // NOLINT /** - * \brief Type for indexing into output targets. + * @brief Type for indexing into output targets. */ using bst_target_t = std::uint32_t; // NOLINT /** @@ -306,8 +301,7 @@ class GradientPairInt64 { XGBOOST_DEVICE bool operator==(const GradientPairInt64 &rhs) const { return grad_ == rhs.grad_ && hess_ == rhs.hess_; } - friend std::ostream &operator<<(std::ostream &os, - const GradientPairInt64 &g) { + friend std::ostream &operator<<(std::ostream &os, const GradientPairInt64 &g) { os << g.GetQuantisedGrad() << "/" << g.GetQuantisedHess(); return os; } @@ -323,7 +317,7 @@ using omp_ulong = dmlc::omp_ulong; // NOLINT /*! \brief define unsigned int for openmp loop */ using bst_omp_uint = dmlc::omp_uint; // NOLINT /*! \brief Type used for representing version number in binary form.*/ -using XGBoostVersionT = int32_t; +using XGBoostVersionT = std::int32_t; } // namespace xgboost #endif // XGBOOST_BASE_H_ diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 795c78946118..4b60fe01a546 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -1,5 +1,5 @@ /** - * Copyright 2015~2023 by XGBoost Contributors + * Copyright 2015-2024, XGBoost Contributors * \file c_api.h * \author Tianqi Chen * \brief C API of XGBoost, used for interfacing to other languages. @@ -639,21 +639,14 @@ XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle, * \param len length of array * \return 0 when success, -1 when failure happens */ -XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, - const char *field, - const float *array, +XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const float *array, bst_ulong len); -/*! - * \brief set uint32 vector to a content in info - * \param handle a instance of data matrix - * \param field field name - * \param array pointer to unsigned int vector - * \param len length of array - * \return 0 when success, -1 when failure happens +/** + * @deprecated since 2.1.0 + * + * Use @ref XGDMatrixSetInfoFromInterface instead. */ -XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, - const char *field, - const unsigned *array, +XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const unsigned *array, bst_ulong len); /*! @@ -725,42 +718,13 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field, bst_ulong *size, const char ***out_features); -/*! - * \brief Set meta info from dense matrix. Valid field names are: - * - * - label - * - weight - * - base_margin - * - group - * - label_lower_bound - * - label_upper_bound - * - feature_weights +/** + * @deprecated since 2.1.0 * - * \param handle An instance of data matrix - * \param field Field name - * \param data Pointer to consecutive memory storing data. - * \param size Size of the data, this is relative to size of type. (Meaning NOT number - * of bytes.) - * \param type Indicator of data type. This is defined in xgboost::DataType enum class. - * - float = 1 - * - double = 2 - * - uint32_t = 3 - * - uint64_t = 4 - * \return 0 when success, -1 when failure happens + * Use @ref XGDMatrixSetInfoFromInterface instead. */ -XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, - void const *data, bst_ulong size, int type); - -/*! - * \brief (deprecated) Use XGDMatrixSetUIntInfo instead. Set group of the training matrix - * \param handle a instance of data matrix - * \param group pointer to group size - * \param len length of array - * \return 0 when success, -1 when failure happens - */ -XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle, - const unsigned *group, - bst_ulong len); +XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void const *data, + bst_ulong size, int type); /*! * \brief get float info vector from matrix. @@ -1153,8 +1117,8 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *values, * * @return 0 when success, -1 when failure happens */ -XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *array_interface, - char const *c_json_config, DMatrixHandle m, +XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *values, + char const *config, DMatrixHandle m, bst_ulong const **out_shape, bst_ulong *out_dim, const float **out_result); @@ -1550,16 +1514,37 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *config, * * @brief Experimental support for exposing internal communicator in XGBoost. * + * @note This is still under development. + * + * The collective communicator in XGBoost evolved from the `rabit` project of dmlc but has + * changed significantly since its adoption. It consists of a tracker and a set of + * workers. The tracker is responsible for bootstrapping the communication group and + * handling centralized tasks like logging. The workers are actual communicators + * performing collective tasks like allreduce. + * + * To use the collective implementation, one needs to first create a tracker with + * corresponding parameters, then get the arguments for workers using + * XGTrackerWorkerArgs(). The obtained arguments can then be passed to the + * XGCommunicatorInit() function. Call to XGCommunicatorInit() must be accompanied with a + * XGCommunicatorFinalize() call for cleanups. Please note that the communicator uses + * `std::thread` in C++, which has undefined behavior in a C++ destructor due to the + * runtime shutdown sequence. It's preferable to call XGCommunicatorFinalize() before the + * runtime is shutting down. This requirement is similar to a Python thread or socket, + * which should not be relied upon in a `__del__` function. + * + * Since it's used as a part of XGBoost, errors will be returned when a XGBoost function + * is called, for instance, training a booster might return a connection error. + * * @{ */ /** - * @brief Handle to tracker. + * @brief Handle to the tracker. * * There are currently two types of tracker in XGBoost, first one is `rabit`, while the - * other one is `federated`. + * other one is `federated`. `rabit` is used for normal collective communication, while + * `federated` is used for federated learning. * - * This is still under development. */ typedef void *TrackerHandle; /* NOLINT */ @@ -1568,17 +1553,23 @@ typedef void *TrackerHandle; /* NOLINT */ * * @param config JSON encoded parameters. * - * - dmlc_communicator: String, the type of tracker to create. Available options are `rabit` - * and `federated`. + * - dmlc_communicator: String, the type of tracker to create. Available options are + * `rabit` and `federated`. See @ref TrackerHandle for more info. * - n_workers: Integer, the number of workers. * - port: (Optional) Integer, the port this tracker should listen to. - * - timeout: (Optional) Integer, timeout in seconds for various networking operations. + * - timeout: (Optional) Integer, timeout in seconds for various networking + operations. Default is 300 seconds. * * Some configurations are `rabit` specific: + * * - host: (Optional) String, Used by the the `rabit` tracker to specify the address of the host. + * This can be useful when the communicator cannot reliably obtain the host address. + * - sortby: (Optional) Integer. + * + 0: Sort workers by their host name. + * + 1: Sort workers by task IDs. * * Some `federated` specific configurations: - * - federated_secure: Boolean, whether this is a secure server. + * - federated_secure: Boolean, whether this is a secure server. False for testing. * - server_key_path: Path to the server key. Used only if this is a secure server. * - server_cert_path: Path to the server certificate. Used only if this is a secure server. * - client_cert_path: Path to the client certificate. Used only if this is a secure server. @@ -1591,7 +1582,7 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle); /** * @brief Get the arguments needed for running workers. This should be called after - * XGTrackerRun() and XGTrackerWait() + * XGTrackerRun(). * * @param handle The handle to the tracker. * @param args The arguments returned as a JSON document. @@ -1601,16 +1592,19 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle); XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args); /** - * @brief Run the tracker. + * @brief Start the tracker. The tracker runs in the background and this function returns + * once the tracker is started. * * @param handle The handle to the tracker. + * @param config Unused at the moment, preserved for the future. * * @return 0 for success, -1 for failure. */ -XGB_DLL int XGTrackerRun(TrackerHandle handle); +XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *config); /** - * @brief Wait for the tracker to finish, should be called after XGTrackerRun(). + * @brief Wait for the tracker to finish, should be called after XGTrackerRun(). This + * function will block until the tracker task is finished or timeout is reached. * * @param handle The handle to the tracker. * @param config JSON encoded configuration. No argument is required yet, preserved for @@ -1618,11 +1612,12 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle); * * @return 0 for success, -1 for failure. */ -XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config); +XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config); /** - * @brief Free a tracker instance. XGTrackerWait() is called internally. If the tracker - * cannot close properly, manual interruption is required. + * @brief Free a tracker instance. This should be called after XGTrackerWaitFor(). If the + * tracker is not properly waited, this function will shutdown all connections with + * the tracker, potentially leading to undefined behavior. * * @param handle The handle to the tracker. * @@ -1630,129 +1625,128 @@ XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config); */ XGB_DLL int XGTrackerFree(TrackerHandle handle); -/*! - * \brief Initialize the collective communicator. +/** + * @brief Initialize the collective communicator. * * Currently the communicator API is experimental, function signatures may change in the future * without notice. * - * Call this once before using anything. - * - * The additional configuration is not required. Usually the communicator will detect settings - * from environment variables. + * Call this once in the worker process before using anything. Please make sure + * XGCommunicatorFinalize() is called after use. The initialized commuicator is a global + * thread-local variable. * - * \param config JSON encoded configuration. Accepted JSON keys are: - * - xgboost_communicator: The type of the communicator. Can be set as an environment variable. + * @param config JSON encoded configuration. Accepted JSON keys are: + * - dmlc_communicator: The type of the communicator, this should match the tracker type. * * rabit: Use Rabit. This is the default if the type is unspecified. * * federated: Use the gRPC interface for Federated Learning. - * Only applicable to the Rabit communicator (these are case-sensitive): - * - rabit_tracker_uri: Hostname of the tracker. - * - rabit_tracker_port: Port number of the tracker. - * - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment. - * - rabit_world_size: Total number of workers. - * - rabit_timeout: Enable timeout. - * - rabit_timeout_sec: Timeout in seconds. - * Only applicable to the Rabit communicator (these are case-sensitive, and can be set as - * environment variables): - * - DMLC_TRACKER_URI: Hostname of the tracker. - * - DMLC_TRACKER_PORT: Port number of the tracker. - * - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment. - * - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker. - * - dmlc_nccl_path: The path to NCCL shared object. Only used if XGBoost is compiled with - * `USE_DLOPEN_NCCL`. - * Only applicable to the Federated communicator (use upper case for environment variables, use + * + * Only applicable to the `rabit` communicator: + * - dmlc_tracker_uri: Hostname or IP address of the tracker. + * - dmlc_tracker_port: Port number of the tracker. + * - dmlc_task_id: ID of the current task, can be used to obtain deterministic rank assignment. + * - dmlc_retry: The number of retries for connection failure. + * - dmlc_timeout: Timeout in seconds. + * - dmlc_nccl_path: Path to the nccl shared library `libnccl.so`. + * + * Only applicable to the `federated` communicator (use upper case for environment variables, use * lower case for runtime configuration): * - federated_server_address: Address of the federated server. * - federated_world_size: Number of federated workers. * - federated_rank: Rank of the current worker. - * - federated_server_cert: Server certificate file path. Only needed for the SSL mode. - * - federated_client_key: Client key file path. Only needed for the SSL mode. - * - federated_client_cert: Client certificate file path. Only needed for the SSL mode. - * \return 0 for success, -1 for failure. + * - federated_server_cert_path: Server certificate file path. Only needed for the SSL mode. + * - federated_client_key_path: Client key file path. Only needed for the SSL mode. + * - federated_client_cert_path: Client certificate file path. Only needed for the SSL mode. + * + * @return 0 for success, -1 for failure. */ XGB_DLL int XGCommunicatorInit(char const* config); -/*! - * \brief Finalize the collective communicator. +/** + * @brief Finalize the collective communicator. * - * Call this function after you finished all jobs. + * Call this function after you have finished all jobs. * - * \return 0 for success, -1 for failure. + * @return 0 for success, -1 for failure. */ XGB_DLL int XGCommunicatorFinalize(void); -/*! - * \brief Get rank of current process. +/** + * @brief Get rank of the current process. * - * \return Rank of the worker. + * @return Rank of the worker. */ XGB_DLL int XGCommunicatorGetRank(void); -/*! - * \brief Get total number of processes. +/** + * @brief Get the total number of processes. * - * \return Total world size. + * @return Total world size. */ XGB_DLL int XGCommunicatorGetWorldSize(void); -/*! - * \brief Get if the communicator is distributed. +/** + * @brief Get if the communicator is distributed. * - * \return True if the communicator is distributed. + * @return True if the communicator is distributed. */ XGB_DLL int XGCommunicatorIsDistributed(void); -/*! - * \brief Print the message to the communicator. +/** + * @brief Print the message to the tracker. * - * This function can be used to communicate the information of the progress to the user who monitors - * the communicator. + * This function can be used to communicate the information of the progress to the user + * who monitors the tracker. * - * \param message The message to be printed. - * \return 0 for success, -1 for failure. + * @param message The message to be printed. + * @return 0 for success, -1 for failure. */ XGB_DLL int XGCommunicatorPrint(char const *message); -/*! - * \brief Get the name of the processor. +/** + * @brief Get the name of the processor. * - * \param name_str Pointer to received returned processor name. - * \return 0 for success, -1 for failure. + * @param name_str Pointer to received returned processor name. + * @return 0 for success, -1 for failure. */ XGB_DLL int XGCommunicatorGetProcessorName(const char** name_str); -/*! - * \brief Broadcast a memory region to all others from root. This function is NOT thread-safe. +/** + * @brief Broadcast a memory region to all others from root. This function is NOT + * thread-safe. * * Example: - * \code + * @code * int a = 1; * Broadcast(&a, sizeof(a), root); - * \endcode + * @endcode * - * \param send_receive_buffer Pointer to the send or receive buffer. - * \param size Size of the data. - * \param root The process rank to broadcast from. - * \return 0 for success, -1 for failure. + * @param send_receive_buffer Pointer to the send or receive buffer. + * @param size Size of the data in bytes. + * @param root The process rank to broadcast from. + * @return 0 for success, -1 for failure. */ XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root); -/*! - * \brief Perform in-place allreduce. This function is NOT thread-safe. +/** + * @brief Perform in-place allreduce. This function is NOT thread-safe. * * Example Usage: the following code gives sum of the result - * \code - * vector data(10); + * @code + * enum class Op { + * kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5 + * }; + * std::vector data(10); * ... - * Allreduce(&data[0], data.size(), DataType:kInt32, Op::kSum); + * Allreduce(data.data(), data.size(), DataType:kInt32, Op::kSum); * ... - * \endcode + * @endcode - * \param send_receive_buffer Buffer for both sending and receiving data. - * \param count Number of elements to be reduced. - * \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h. - * \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h. - * \return 0 for success, -1 for failure. + * @param send_receive_buffer Buffer for both sending and receiving data. + * @param count Number of elements to be reduced. + * @param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h. + * @param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h. + * + * @return 0 for success, -1 for failure. */ XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int data_type, int op); diff --git a/include/xgboost/collective/result.h b/include/xgboost/collective/result.h index 919d3a902298..c126366a07a0 100644 --- a/include/xgboost/collective/result.h +++ b/include/xgboost/collective/result.h @@ -3,13 +3,11 @@ */ #pragma once -#include - -#include // for unique_ptr -#include // for stringstream -#include // for stack -#include // for string -#include // for move +#include // for int32_t +#include // for unique_ptr +#include // for string +#include // for error_code +#include // for move namespace xgboost::collective { namespace detail { @@ -48,48 +46,18 @@ struct ResultImpl { return cur_eq; } - [[nodiscard]] std::string Report() { - std::stringstream ss; - ss << "\n- " << this->message; - if (this->errc != std::error_code{}) { - ss << " system error:" << this->errc.message(); - } + [[nodiscard]] std::string Report() const; + [[nodiscard]] std::error_code Code() const; - auto ptr = prev.get(); - while (ptr) { - ss << "\n- "; - ss << ptr->message; + void Concat(std::unique_ptr rhs); +}; - if (ptr->errc != std::error_code{}) { - ss << " " << ptr->errc.message(); - } - ptr = ptr->prev.get(); - } +#if (!defined(__GNUC__) && !defined(__clang__)) || defined(__MINGW32__) +#define __builtin_FILE() nullptr +#define __builtin_LINE() (-1) +#endif - return ss.str(); - } - [[nodiscard]] auto Code() const { - // Find the root error. - std::stack stack; - auto ptr = this; - while (ptr) { - stack.push(ptr); - if (ptr->prev) { - ptr = ptr->prev.get(); - } else { - break; - } - } - while (!stack.empty()) { - auto frame = stack.top(); - stack.pop(); - if (frame->errc != std::error_code{}) { - return frame->errc; - } - } - return std::error_code{}; - } -}; +std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line); } // namespace detail /** @@ -131,8 +99,21 @@ struct Result { } return *impl_ == *that.impl_; } + + friend Result operator+(Result&& lhs, Result&& rhs); }; +[[nodiscard]] inline Result operator+(Result&& lhs, Result&& rhs) { + if (lhs.OK()) { + return std::forward(rhs); + } + if (rhs.OK()) { + return std::forward(lhs); + } + lhs.impl_->Concat(std::move(rhs.impl_)); + return std::forward(lhs); +} + /** * @brief Return success. */ @@ -140,38 +121,43 @@ struct Result { /** * @brief Return failure. */ -[[nodiscard]] inline auto Fail(std::string msg) { return Result{std::move(msg)}; } +[[nodiscard]] inline auto Fail(std::string msg, char const* file = __builtin_FILE(), + std::int32_t line = __builtin_LINE()) { + return Result{detail::MakeMsg(std::move(msg), file, line)}; +} /** * @brief Return failure with `errno`. */ -[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc) { - return Result{std::move(msg), std::move(errc)}; +[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, + char const* file = __builtin_FILE(), + std::int32_t line = __builtin_LINE()) { + return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc)}; } /** * @brief Return failure with a previous error. */ -[[nodiscard]] inline auto Fail(std::string msg, Result&& prev) { - return Result{std::move(msg), std::forward(prev)}; +[[nodiscard]] inline auto Fail(std::string msg, Result&& prev, char const* file = __builtin_FILE(), + std::int32_t line = __builtin_LINE()) { + return Result{detail::MakeMsg(std::move(msg), file, line), std::forward(prev)}; } /** * @brief Return failure with a previous error and a new `errno`. */ -[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev) { - return Result{std::move(msg), std::move(errc), std::forward(prev)}; +[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev, + char const* file = __builtin_FILE(), + std::int32_t line = __builtin_LINE()) { + return Result{detail::MakeMsg(std::move(msg), file, line), std::move(errc), + std::forward(prev)}; } // We don't have monad, a simple helper would do. template -[[nodiscard]] Result operator<<(Result&& r, Fn&& fn) { +[[nodiscard]] std::enable_if_t, Result> operator<<(Result&& r, Fn&& fn) { if (!r.OK()) { return std::forward(r); } return fn(); } -inline void SafeColl(Result const& rc) { - if (!rc.OK()) { - LOG(FATAL) << rc.Report(); - } -} +void SafeColl(Result const& rc); } // namespace xgboost::collective diff --git a/include/xgboost/collective/socket.h b/include/xgboost/collective/socket.h index 84453411046e..c5dd977f6255 100644 --- a/include/xgboost/collective/socket.h +++ b/include/xgboost/collective/socket.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2022-2023, XGBoost Contributors + * Copyright (c) 2022-2024, XGBoost Contributors */ #pragma once @@ -12,11 +12,14 @@ #include // std::size_t #include // std::int32_t, std::uint16_t #include // memset -#include // std::numeric_limits #include // std::string #include // std::error_code, std::system_category #include // std::swap +#if defined(__linux__) +#include // for TIOCOUTQ, FIONREAD +#endif // defined(__linux__) + #if !defined(xgboost_IS_MINGW) #if defined(__MINGW32__) @@ -125,6 +128,21 @@ inline std::int32_t CloseSocket(SocketT fd) { #endif } +inline std::int32_t ShutdownSocket(SocketT fd) { +#if defined(_WIN32) + auto rc = shutdown(fd, SD_BOTH); + if (rc != 0 && LastError() == WSANOTINITIALISED) { + return 0; + } +#else + auto rc = shutdown(fd, SHUT_RDWR); + if (rc != 0 && LastError() == ENOTCONN) { + return 0; + } +#endif + return rc; +} + inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) { #ifdef _WIN32 return errsv == WSAEWOULDBLOCK; @@ -305,7 +323,8 @@ class TCPSocket { std::int32_t domain; socklen_t len = sizeof(domain); xgboost_CHECK_SYS_CALL( - getsockopt(handle_, SOL_SOCKET, SO_DOMAIN, reinterpret_cast(&domain), &len), 0); + getsockopt(this->Handle(), SOL_SOCKET, SO_DOMAIN, reinterpret_cast(&domain), &len), + 0); return ret_iafamily(domain); #else struct sockaddr sa; @@ -412,6 +431,35 @@ class TCPSocket { return Success(); } + [[nodiscard]] Result SendBufSize(std::int32_t *n_bytes) { + socklen_t optlen; + auto rc = getsockopt(this->Handle(), SOL_SOCKET, SO_SNDBUF, reinterpret_cast(n_bytes), + &optlen); + if (rc != 0 || optlen != sizeof(std::int32_t)) { + return system::FailWithCode("getsockopt"); + } + return Success(); + } + [[nodiscard]] Result RecvBufSize(std::int32_t *n_bytes) { + socklen_t optlen; + auto rc = getsockopt(this->Handle(), SOL_SOCKET, SO_RCVBUF, reinterpret_cast(n_bytes), + &optlen); + if (rc != 0 || optlen != sizeof(std::int32_t)) { + return system::FailWithCode("getsockopt"); + } + return Success(); + } +#if defined(__linux__) + [[nodiscard]] Result PendingSendSize(std::int32_t *n_bytes) const { + return ioctl(this->Handle(), TIOCOUTQ, n_bytes) == 0 ? Success() + : system::FailWithCode("ioctl"); + } + [[nodiscard]] Result PendingRecvSize(std::int32_t *n_bytes) const { + return ioctl(this->Handle(), FIONREAD, n_bytes) == 0 ? Success() + : system::FailWithCode("ioctl"); + } +#endif // defined(__linux__) + [[nodiscard]] Result SetKeepAlive() { std::int32_t keepalive = 1; auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast(&keepalive), @@ -422,10 +470,9 @@ class TCPSocket { return Success(); } - [[nodiscard]] Result SetNoDelay() { - std::int32_t tcp_no_delay = 1; - auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&tcp_no_delay), - sizeof(tcp_no_delay)); + [[nodiscard]] Result SetNoDelay(std::int32_t no_delay = 1) { + auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&no_delay), + sizeof(no_delay)); if (rc != 0) { return system::FailWithCode("Failed to set TCP no delay."); } @@ -436,41 +483,62 @@ class TCPSocket { * \brief Accept new connection, returns a new TCP socket for the new connection. */ TCPSocket Accept() { - HandleT newfd = accept(Handle(), nullptr, nullptr); + SockAddress addr; + TCPSocket newsock; + auto rc = this->Accept(&newsock, &addr); + SafeColl(rc); + return newsock; + } + + [[nodiscard]] Result Accept(TCPSocket *out, SockAddress *addr) { #if defined(_WIN32) auto interrupt = WSAEINTR; #else auto interrupt = EINTR; #endif - if (newfd == InvalidSocket() && system::LastError() != interrupt) { - system::ThrowAtError("accept"); + if (this->Domain() == SockDomain::kV4) { + struct sockaddr_in caddr; + socklen_t caddr_len = sizeof(caddr); + HandleT newfd = accept(Handle(), reinterpret_cast(&caddr), &caddr_len); + if (newfd == InvalidSocket() && system::LastError() != interrupt) { + return system::FailWithCode("Failed to accept."); + } + *addr = SockAddress{SockAddrV4{caddr}}; + *out = TCPSocket{newfd}; + } else { + struct sockaddr_in6 caddr; + socklen_t caddr_len = sizeof(caddr); + HandleT newfd = accept(Handle(), reinterpret_cast(&caddr), &caddr_len); + if (newfd == InvalidSocket() && system::LastError() != interrupt) { + return system::FailWithCode("Failed to accept."); + } + *addr = SockAddress{SockAddrV6{caddr}}; + *out = TCPSocket{newfd}; } - TCPSocket newsock{newfd}; - return newsock; - } - - [[nodiscard]] Result Accept(TCPSocket *out, SockAddrV4 *addr) { - struct sockaddr_in caddr; - socklen_t caddr_len = sizeof(caddr); - HandleT newfd = accept(Handle(), reinterpret_cast(&caddr), &caddr_len); - if (newfd == InvalidSocket()) { - return system::FailWithCode("Failed to accept."); + // On MacOS, this is automatically set to async socket if the parent socket is async + // We make sure all socket are blocking by default. + // + // On Windows, a closed socket is returned during shutdown. We guard against it when + // setting non-blocking. + if (!out->IsClosed()) { + return out->NonBlocking(false); } - *addr = SockAddrV4{caddr}; - *out = TCPSocket{newfd}; return Success(); } ~TCPSocket() { if (!IsClosed()) { - Close(); + auto rc = this->Close(); + if (!rc.OK()) { + LOG(WARNING) << rc.Report(); + } } } TCPSocket(TCPSocket const &that) = delete; TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); } TCPSocket &operator=(TCPSocket const &that) = delete; - TCPSocket &operator=(TCPSocket &&that) { + TCPSocket &operator=(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); return *this; } @@ -479,36 +547,49 @@ class TCPSocket { */ [[nodiscard]] HandleT const &Handle() const { return handle_; } /** - * \brief Listen to incoming requests. Should be called after bind. + * @brief Listen to incoming requests. Should be called after bind. */ - void Listen(std::int32_t backlog = 16) { xgboost_CHECK_SYS_CALL(listen(handle_, backlog), 0); } + [[nodiscard]] Result Listen(std::int32_t backlog = 16) { + if (listen(handle_, backlog) != 0) { + return system::FailWithCode("Failed to listen."); + } + return Success(); + } /** - * \brief Bind socket to INADDR_ANY, return the port selected by the OS. + * @brief Bind socket to INADDR_ANY, return the port selected by the OS. */ - [[nodiscard]] in_port_t BindHost() { + [[nodiscard]] Result BindHost(std::int32_t* p_out) { + // Use int32 instead of in_port_t for consistency. We take port as parameter from + // users using other languages, the port is usually stored and passed around as int. if (Domain() == SockDomain::kV6) { auto addr = SockAddrV6::InaddrAny(); auto handle = reinterpret_cast(&addr.Handle()); - xgboost_CHECK_SYS_CALL( - bind(handle_, handle, sizeof(std::remove_reference_t)), 0); + if (bind(handle_, handle, sizeof(std::remove_reference_t)) != 0) { + return system::FailWithCode("bind failed."); + } sockaddr_in6 res_addr; socklen_t addrlen = sizeof(res_addr); - xgboost_CHECK_SYS_CALL( - getsockname(handle_, reinterpret_cast(&res_addr), &addrlen), 0); - return ntohs(res_addr.sin6_port); + if (getsockname(handle_, reinterpret_cast(&res_addr), &addrlen) != 0) { + return system::FailWithCode("getsockname failed."); + } + *p_out = ntohs(res_addr.sin6_port); } else { auto addr = SockAddrV4::InaddrAny(); auto handle = reinterpret_cast(&addr.Handle()); - xgboost_CHECK_SYS_CALL( - bind(handle_, handle, sizeof(std::remove_reference_t)), 0); + if (bind(handle_, handle, sizeof(std::remove_reference_t)) != 0) { + return system::FailWithCode("bind failed."); + } sockaddr_in res_addr; socklen_t addrlen = sizeof(res_addr); - xgboost_CHECK_SYS_CALL( - getsockname(handle_, reinterpret_cast(&res_addr), &addrlen), 0); - return ntohs(res_addr.sin_port); + if (getsockname(handle_, reinterpret_cast(&res_addr), &addrlen) != 0) { + return system::FailWithCode("getsockname failed."); + } + *p_out = ntohs(res_addr.sin_port); } + + return Success(); } [[nodiscard]] auto Port() const { @@ -554,45 +635,47 @@ class TCPSocket { } /** - * \brief Send data, without error then all data should be sent. + * @brief Send data, without error then all data should be sent. */ - [[nodiscard]] auto SendAll(void const *buf, std::size_t len) { + [[nodiscard]] Result SendAll(void const *buf, std::size_t len, std::size_t *n_sent) { char const *_buf = reinterpret_cast(buf); - std::size_t ndone = 0; + std::size_t &ndone = *n_sent; + ndone = 0; while (ndone < len) { ssize_t ret = send(handle_, _buf, len - ndone, 0); if (ret == -1) { if (system::LastErrorWouldBlock()) { - return ndone; + return Success(); } - system::ThrowAtError("send"); + return system::FailWithCode("send"); } _buf += ret; ndone += ret; } - return ndone; + return Success(); } /** - * \brief Receive data, without error then all data should be received. + * @brief Receive data, without error then all data should be received. */ - [[nodiscard]] auto RecvAll(void *buf, std::size_t len) { + [[nodiscard]] Result RecvAll(void *buf, std::size_t len, std::size_t *n_recv) { char *_buf = reinterpret_cast(buf); - std::size_t ndone = 0; + std::size_t &ndone = *n_recv; + ndone = 0; while (ndone < len) { ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL); if (ret == -1) { if (system::LastErrorWouldBlock()) { - return ndone; + return Success(); } - system::ThrowAtError("recv"); + return system::FailWithCode("recv"); } if (ret == 0) { - return ndone; + return Success(); } _buf += ret; ndone += ret; } - return ndone; + return Success(); } /** * \brief Send data using the socket @@ -621,26 +704,49 @@ class TCPSocket { */ std::size_t Send(StringView str); /** - * \brief Receive string, format is matched with the Python socket wrapper in RABIT. + * @brief Receive string, format is matched with the Python socket wrapper in RABIT. */ - std::size_t Recv(std::string *p_str); + [[nodiscard]] Result Recv(std::string *p_str); /** - * \brief Close the socket, called automatically in destructor if the socket is not closed. + * @brief Close the socket, called automatically in destructor if the socket is not closed. */ - void Close() { + [[nodiscard]] Result Close() { if (InvalidSocket() != handle_) { -#if defined(_WIN32) auto rc = system::CloseSocket(handle_); +#if defined(_WIN32) // it's possible that we close TCP sockets after finalizing WSA due to detached thread. if (rc != 0 && system::LastError() != WSANOTINITIALISED) { - system::ThrowAtError("close", rc); + return system::FailWithCode("Failed to close the socket."); } #else - xgboost_CHECK_SYS_CALL(system::CloseSocket(handle_), 0); + if (rc != 0) { + return system::FailWithCode("Failed to close the socket."); + } #endif handle_ = InvalidSocket(); } + return Success(); } + /** + * @brief Call shutdown on the socket. + */ + [[nodiscard]] Result Shutdown() { + if (this->IsClosed()) { + return Success(); + } + auto rc = system::ShutdownSocket(this->Handle()); +#if defined(_WIN32) + // Windows cannot shutdown a socket if it's not connected. + if (rc == -1 && system::LastError() == WSAENOTCONN) { + return Success(); + } +#endif + if (rc != 0) { + return system::FailWithCode("Failed to shutdown socket."); + } + return Success(); + } + /** * \brief Create a TCP socket on specified domain. */ diff --git a/include/xgboost/data.h b/include/xgboost/data.h index c449164ca572..05e2cb0080f0 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -137,14 +136,6 @@ class MetaInfo { * \param fo The output stream. */ void SaveBinary(dmlc::Stream* fo) const; - /*! - * \brief Set information in the meta info. - * \param key The key of the information. - * \param dptr The data pointer of the source array. - * \param dtype The type of the source data. - * \param num Number of elements in the source array. - */ - void SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype, size_t num); /*! * \brief Set information in the meta info with array interface. * \param key The key of the information. @@ -320,7 +311,7 @@ struct BatchParam { struct HostSparsePageView { using Inst = common::Span; - common::Span offset; + common::Span offset; common::Span data; Inst operator[](size_t i) const { @@ -338,7 +329,7 @@ struct HostSparsePageView { class SparsePage { public: // Offset for each row. - HostDeviceVector offset; + HostDeviceVector offset; /*! \brief the data of the segments */ HostDeviceVector data; @@ -522,10 +513,6 @@ class DMatrix { DMatrix() = default; /*! \brief meta information of the dataset */ virtual MetaInfo& Info() = 0; - virtual void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) { - auto const& ctx = *this->Ctx(); - this->Info().SetInfo(ctx, key, dptr, dtype, num); - } virtual void SetInfo(const char* key, std::string const& interface_str) { auto const& ctx = *this->Ctx(); this->Info().SetInfo(ctx, key, StringView{interface_str}); diff --git a/include/xgboost/json.h b/include/xgboost/json.h index 77ca6a510c96..1416b8899785 100644 --- a/include/xgboost/json.h +++ b/include/xgboost/json.h @@ -60,9 +60,7 @@ class Value { virtual Json& operator[](int ind); virtual bool operator==(Value const& rhs) const = 0; -#if !defined(__APPLE__) virtual Value& operator=(Value const& rhs) = delete; -#endif // !defined(__APPLE__) std::string TypeStr() const; @@ -105,6 +103,7 @@ class JsonString : public Value { std::string& GetString() & { return str_; } bool operator==(Value const& rhs) const override; + Value& operator=(Value const& rhs) override = delete; static bool IsClassOf(Value const* value) { return value->Type() == ValueKind::kString; @@ -134,6 +133,7 @@ class JsonArray : public Value { std::vector& GetArray() & { return vec_; } bool operator==(Value const& rhs) const override; + Value& operator=(Value const& rhs) override = delete; static bool IsClassOf(Value const* value) { return value->Type() == ValueKind::kArray; @@ -158,6 +158,7 @@ class JsonTypedArray : public Value { JsonTypedArray(JsonTypedArray&& that) noexcept : Value{kind}, vec_{std::move(that.vec_)} {} bool operator==(Value const& rhs) const override; + Value& operator=(Value const& rhs) override = delete; void Set(size_t i, T v) { vec_[i] = v; } size_t Size() const { return vec_.size(); } @@ -216,6 +217,7 @@ class JsonObject : public Value { Map& GetObject() & { return object_; } bool operator==(Value const& rhs) const override; + Value& operator=(Value const& rhs) override = delete; static bool IsClassOf(Value const* value) { return value->Type() == ValueKind::kObject; } ~JsonObject() override = default; @@ -249,6 +251,7 @@ class JsonNumber : public Value { Float& GetNumber() & { return number_; } bool operator==(Value const& rhs) const override; + Value& operator=(Value const& rhs) override = delete; static bool IsClassOf(Value const* value) { return value->Type() == ValueKind::kNumber; @@ -287,6 +290,7 @@ class JsonInteger : public Value { : Value{ValueKind::kInteger}, integer_{that.integer_} {} bool operator==(Value const& rhs) const override; + Value& operator=(Value const& rhs) override = delete; Int const& GetInteger() && { return integer_; } Int const& GetInteger() const & { return integer_; } @@ -307,6 +311,7 @@ class JsonNull : public Value { void Save(JsonWriter* writer) const override; bool operator==(Value const& rhs) const override; + Value& operator=(Value const& rhs) override = delete; static bool IsClassOf(Value const* value) { return value->Type() == ValueKind::kNull; @@ -336,6 +341,7 @@ class JsonBoolean : public Value { bool& GetBoolean() & { return boolean_; } bool operator==(Value const& rhs) const override; + Value& operator=(Value const& rhs) override = delete; static bool IsClassOf(Value const* value) { return value->Type() == ValueKind::kBoolean; diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h index 581b2f0804c9..cb7668f4cdd1 100644 --- a/include/xgboost/linalg.h +++ b/include/xgboost/linalg.h @@ -190,13 +190,14 @@ constexpr auto ArrToTuple(T (&arr)[N]) { // uint division optimization inspired by the CIndexer in cupy. Division operation is // slow on both CPU and GPU, especially 64 bit integer. So here we first try to avoid 64 // bit when the index is smaller, then try to avoid division when it's exp of 2. -template +template LINALG_HD auto UnravelImpl(I idx, common::Span shape) { - size_t index[D]{0}; + std::size_t index[D]{0}; static_assert(std::is_signed::value, "Don't change the type without changing the for loop."); + auto const sptr = shape.data(); for (int32_t dim = D; --dim > 0;) { - auto s = static_cast>>(shape[dim]); + auto s = static_cast>>(sptr[dim]); if (s & (s - 1)) { auto t = idx / s; index[dim] = idx - t * s; @@ -295,6 +296,9 @@ class TensorView { using ShapeT = std::size_t[kDim]; using StrideT = ShapeT; + using element_type = T; // NOLINT + using value_type = std::remove_cv_t; // NOLINT + private: StrideT stride_{1}; ShapeT shape_{0}; @@ -314,7 +318,7 @@ class TensorView { } template - LINALG_HD size_t MakeSliceDim(size_t new_shape[D], size_t new_stride[D], + LINALG_HD size_t MakeSliceDim(std::size_t new_shape[D], std::size_t new_stride[D], detail::RangeTag &&range) const { static_assert(new_dim < D); static_assert(old_dim < kDim); @@ -528,9 +532,10 @@ class TensorView { LINALG_HD auto Stride(size_t i) const { return stride_[i]; } /** - * \brief Number of items in the tensor. + * @brief Number of items in the tensor. */ [[nodiscard]] LINALG_HD std::size_t Size() const { return size_; } + [[nodiscard]] bool Empty() const { return Size() == 0; } /** * \brief Whether this is a contiguous array, both C and F contiguous returns true. */ @@ -741,6 +746,14 @@ auto ArrayInterfaceStr(TensorView const &t) { return str; } +template +auto Make1dInterface(T const *vec, std::size_t len) { + Context ctx; + auto t = linalg::MakeTensorView(&ctx, common::Span{vec, len}, len); + auto str = linalg::ArrayInterfaceStr(t); + return str; +} + /** * \brief A tensor storage. To use it for other functionality like slicing one needs to * obtain a view first. This way we can use it on both host and device. @@ -865,7 +878,9 @@ class Tensor { auto HostView() { return this->View(DeviceOrd::CPU()); } auto HostView() const { return this->View(DeviceOrd::CPU()); } - [[nodiscard]] size_t Size() const { return data_.Size(); } + [[nodiscard]] std::size_t Size() const { return data_.Size(); } + [[nodiscard]] bool Empty() const { return Size() == 0; } + auto Shape() const { return common::Span{shape_}; } auto Shape(size_t i) const { return shape_[i]; } diff --git a/include/xgboost/span.h b/include/xgboost/span.h index be8640f73695..7471c2e44ed6 100644 --- a/include/xgboost/span.h +++ b/include/xgboost/span.h @@ -30,9 +30,8 @@ #define XGBOOST_SPAN_H_ #include -#include -#include // size_t +#include // size_t #include #include #include // numeric_limits @@ -73,8 +72,7 @@ #endif // defined(_MSC_VER) && _MSC_VER < 1910 -namespace xgboost { -namespace common { +namespace xgboost::common { #if defined(__CUDA_ARCH__) // Usual logging facility is not available inside device code. @@ -701,14 +699,14 @@ class IterSpan { return {data() + _offset, _count == dynamic_extent ? size() - _offset : _count}; } [[nodiscard]] XGBOOST_DEVICE constexpr iterator begin() const noexcept { // NOLINT - return {this, 0}; + return it_; } [[nodiscard]] XGBOOST_DEVICE constexpr iterator end() const noexcept { // NOLINT - return {this, size()}; + return it_ + size(); } }; -} // namespace common -} // namespace xgboost +} // namespace xgboost::common + #if defined(_MSC_VER) &&_MSC_VER < 1910 #undef constexpr diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 4c475da2ea29..32b93c5cacaf 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -1,5 +1,5 @@ /** - * Copyright 2014-2023 by Contributors + * Copyright 2014-2024, XGBoost Contributors * \file tree_model.h * \brief model structure for tree * \author Tianqi Chen @@ -688,6 +688,9 @@ class RegTree : public Model { } return (*this)[nidx].DefaultLeft(); } + [[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const { + return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx); + } [[nodiscard]] bool IsRoot(bst_node_t nidx) const { if (IsMultiTarget()) { return nidx == kRoot; diff --git a/jvm-packages/create_jni.py b/jvm-packages/create_jni.py index 865d07fe8b0f..693546862b63 100755 --- a/jvm-packages/create_jni.py +++ b/jvm-packages/create_jni.py @@ -23,6 +23,7 @@ "USE_NCCL": "OFF", "JVM_BINDINGS": "ON", "LOG_CAPI_INVOCATION": "OFF", + "CMAKE_EXPORT_COMPILE_COMMANDS": "ON", } @@ -97,10 +98,6 @@ def native_build(args): args = ["-D{0}:BOOL={1}".format(k, v) for k, v in CONFIG.items()] - # if enviorment set rabit_mock - if os.getenv("RABIT_MOCK", None) is not None: - args.append("-DRABIT_MOCK:BOOL=ON") - # if enviorment set GPU_ARCH_FLAG gpu_arch_flag = os.getenv("GPU_ARCH_FLAG", None) if gpu_arch_flag is not None: @@ -162,12 +159,6 @@ def native_build(args): maybe_makedirs(output_folder) cp("../lib/" + library_name, output_folder) - print("copying pure-Python tracker") - cp( - "../python-package/xgboost/tracker.py", - "{}/src/main/resources".format(xgboost4j), - ) - print("copying train/test files") maybe_makedirs("{}/src/test/resources".format(xgboost4j_spark)) with cd("../demo/CLI/regression"): diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index 23ab70734ac6..17afbe48d2cc 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -33,21 +33,21 @@ UTF-8 1.8 1.8 - 1.18.0 + 1.19.0 4.13.2 3.4.1 3.4.1 2.12.18 2.12 - 3.3.6 + 3.4.0 5 OFF OFF 23.12.1 23.12.1 cuda12 - 3.2.17 - 2.11.0 + 3.2.18 + 2.12.0 @@ -123,7 +123,7 @@ org.apache.maven.plugins maven-jar-plugin - 3.3.0 + 3.4.1 empty-javadoc-jar @@ -152,7 +152,7 @@ org.apache.maven.plugins maven-gpg-plugin - 3.1.0 + 3.2.4 sign-artifacts @@ -166,7 +166,7 @@ org.apache.maven.plugins maven-source-plugin - 3.3.0 + 3.3.1 attach-sources @@ -204,7 +204,7 @@ org.apache.maven.plugins maven-assembly-plugin - 3.6.0 + 3.7.1 jar-with-dependencies @@ -275,7 +275,7 @@ org.apache.maven.plugins maven-deploy-plugin - 3.1.1 + 3.1.2 internal.repo::default::file://${project.build.directory}/mvn-repo @@ -410,7 +410,7 @@ net.alchim31.maven scala-maven-plugin - 4.8.1 + 4.9.0 compile @@ -445,7 +445,7 @@ org.apache.maven.plugins maven-surefire-plugin - 3.2.2 + 3.2.5 false false @@ -473,7 +473,7 @@ net.alchim31.maven scala-maven-plugin - 4.8.1 + 4.9.0 -Xms64m @@ -487,12 +487,17 @@ com.esotericsoftware kryo - 5.5.0 + 5.6.0 + + + com.fasterxml.jackson.core + jackson-databind + 2.14.2 commons-logging commons-logging - 1.3.0 + 1.3.2 org.scalatest diff --git a/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java b/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java index 7a5e3ac68815..99608b927489 100644 --- a/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java +++ b/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java @@ -54,9 +54,9 @@ private static class MapFunction private final Map params; private final int round; - private final Map workerEnvs; + private final Map workerEnvs; - public MapFunction(Map params, int round, Map workerEnvs) { + public MapFunction(Map params, int round, Map workerEnvs) { this.params = params; this.round = round; this.workerEnvs = workerEnvs; @@ -174,9 +174,9 @@ public static XGBoostModel train(DataSet> dtrain, int numBoostRound) throws Exception { final RabitTracker tracker = new RabitTracker(dtrain.getExecutionEnvironment().getParallelism()); - if (tracker.start(0L)) { + if (tracker.start()) { return dtrain - .mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerEnvs())) + .mapPartition(new MapFunction(params, numBoostRound, tracker.workerArgs())) .reduce((x, y) -> x) .collect() .get(0); diff --git a/jvm-packages/xgboost4j-gpu/pom.xml b/jvm-packages/xgboost4j-gpu/pom.xml index fc55dd15618c..25b44d6b2d2d 100644 --- a/jvm-packages/xgboost4j-gpu/pom.xml +++ b/jvm-packages/xgboost4j-gpu/pom.xml @@ -72,7 +72,7 @@ org.apache.maven.plugins maven-javadoc-plugin - 3.6.2 + 3.6.3 protected true @@ -88,7 +88,7 @@ exec-maven-plugin org.codehaus.mojo - 3.1.0 + 3.2.0 native @@ -113,7 +113,7 @@ org.apache.maven.plugins maven-jar-plugin - 3.3.0 + 3.4.0 diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala index d34802805d79..7e83dc6f17b0 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala @@ -160,7 +160,7 @@ object GpuPreXGBoost extends PreXGBoostProvider { // Check columns and build column data batch val trainingData = GpuUtils.buildColumnDataBatch(feturesCols, - labelName, weightName, marginName, "", castedDF) + labelName, weightName, marginName, groupName, castedDF) // eval map val evalDataMap = evalSets.map { diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 5a1af886fc3d..e17c68355c5b 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2023 by Contributors + Copyright (c) 2014-2024 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ import scala.collection.mutable import scala.util.Random import scala.collection.JavaConverters._ -import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker} +import ml.dmlc.xgboost4j.java.{Communicator, ITracker, XGBoostError, RabitTracker} import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} @@ -38,21 +38,17 @@ import org.apache.spark.sql.SparkSession /** * Rabit tracker configurations. * - * @param workerConnectionTimeout The timeout for all workers to connect to the tracker. - * Set timeout length to zero to disable timeout. - * Use a finite, non-zero timeout value to prevent tracker from - * hanging indefinitely (in milliseconds) - * (supported by "scala" implementation only.) - * @param hostIp The Rabit Tracker host IP address which is only used for python implementation. + * @param timeout The number of seconds before timeout waiting for workers to connect. and + * for the tracker to shutdown. + * @param hostIp The Rabit Tracker host IP address. * This is only needed if the host IP cannot be automatically guessed. - * @param pythonExec The python executed path for Rabit Tracker, - * which is only used for python implementation. + * @param port The port number for the tracker to listen to. Use a system allocated one by + * default. */ -case class TrackerConf(workerConnectionTimeout: Long, - hostIp: String = "", pythonExec: String = "") +case class TrackerConf(timeout: Int, hostIp: String = "", port: Int = 0) object TrackerConf { - def apply(): TrackerConf = TrackerConf(0L) + def apply(): TrackerConf = TrackerConf(0) } private[scala] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long) @@ -421,7 +417,7 @@ object XGBoost extends XGBoostStageLevel { private def buildDistributedBooster( buildWatches: () => Watches, xgbExecutionParam: XGBoostExecutionParams, - rabitEnv: java.util.Map[String, String], + rabitEnv: java.util.Map[String, Object], obj: ObjectiveTrait, eval: EvalTrait, prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = { @@ -430,7 +426,6 @@ object XGBoost extends XGBoostStageLevel { val taskId = TaskContext.getPartitionId().toString val attempt = TaskContext.get().attemptNumber.toString rabitEnv.put("DMLC_TASK_ID", taskId) - rabitEnv.put("DMLC_NUM_ATTEMPT", attempt) val numRounds = xgbExecutionParam.numRounds val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0 @@ -481,16 +476,15 @@ object XGBoost extends XGBoostStageLevel { } /** visiable for testing */ - private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = { - val tracker: IRabitTracker = new PyRabitTracker( - nWorkers, trackerConf.hostIp, trackerConf.pythonExec - ) + private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = { + val tracker: ITracker = new RabitTracker( + nWorkers, trackerConf.hostIp, trackerConf.port, trackerConf.timeout) tracker } - private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = { + private def startTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = { val tracker = getTracker(nWorkers, trackerConf) - require(tracker.start(trackerConf.workerConnectionTimeout), "FAULT: Failed to start tracker") + require(tracker.start(), "FAULT: Failed to start tracker") tracker } @@ -525,8 +519,8 @@ object XGBoost extends XGBoostStageLevel { // Train for every ${savingRound} rounds and save the partially completed booster val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf) val (booster, metrics) = try { - tracker.getWorkerEnvs().putAll(xgbRabitParams) - val rabitEnv = tracker.getWorkerEnvs + tracker.workerArgs().putAll(xgbRabitParams) + val rabitEnv = tracker.workerArgs val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter => { var optionWatches: Option[() => Watches] = None @@ -548,11 +542,6 @@ object XGBoost extends XGBoostStageLevel { // of the training task fails the training stage can retry. ResultStage won't retry when // it fails. val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0) - val trackerReturnVal = tracker.waitFor(0L) - logger.info(s"Rabit returns with exit code $trackerReturnVal") - if (trackerReturnVal != 0) { - throw new XGBoostError("XGBoostModel training failed.") - } (booster, metrics) } finally { tracker.stop() diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala index b85f4dc8b3ad..fafbd816a265 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2022 by Contributors + Copyright (c) 2014-2024 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -145,28 +145,28 @@ private[spark] trait GeneralParams extends Params { * Rabit tracker configurations. The parameter must be provided as an instance of the * TrackerConf class, which has the following definition: * - * case class TrackerConf(workerConnectionTimeout: Duration, trainingTimeout: Duration, - * trackerImpl: String) + * case class TrackerConf(timeout: Int, hostIp: String, port: Int) * * See below for detailed explanations. * - * - trackerImpl: Select the implementation of Rabit tracker. - * default: "python" - * - * Choice between "python" or "scala". The former utilizes the Java wrapper of the - * Python Rabit tracker (in dmlc_core), and does not support timeout settings. - * The "scala" version removes Python components, and fully supports timeout settings. - * - * - workerConnectionTimeout: the maximum wait time for all workers to connect to the tracker. - * default: 0 millisecond (no timeout) + * - timeout : The maximum wait time for all workers to connect to the tracker. (in seconds) + * default: 0 (no timeout) * + * Timeout for constructing the communication group and waiting for the tracker to + * shutdown when it's instructed to, doesn't apply to communication when tracking + * is running. * The timeout value should take the time of data loading and pre-processing into account, - * due to the lazy execution of Spark's operations. Alternatively, you may force Spark to + * due to potential lazy execution. Alternatively, you may force Spark to * perform data transformation before calling XGBoost.train(), so that this timeout truly * reflects the connection delay. Set a reasonable timeout value to prevent model * training/testing from hanging indefinitely, possible due to network issues. * Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf). - * Ignored if the tracker implementation is "python". + * + * - hostIp : The Rabit Tracker host IP address. This is only needed if the host IP + * cannot be automatically guessed. + * + * - port : The port number for the tracker to listen to. Use a system allocated one by + * default. */ final val trackerConf = new TrackerConfParam(this, "trackerConf", "Rabit tracker configurations") setDefault(trackerConf, TrackerConf()) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala index 5445cd1bf6a1..108053af5d76 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2022 by Contributors + Copyright (c) 2014-2024 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,8 +20,7 @@ import java.util.concurrent.LinkedBlockingDeque import scala.util.Random -import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker => PyRabitTracker} -import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus +import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker} import ml.dmlc.xgboost4j.scala.DMatrix import org.scalatest.funsuite.AnyFunSuite @@ -33,50 +32,6 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest { xgbParamsFactory.buildXGBRuntimeParams } - test("Customize host ip and python exec for Rabit tracker") { - val hostIp = "192.168.22.111" - val pythonExec = "/usr/bin/python3" - - val paramMap = Map( - "num_workers" -> numWorkers, - "tracker_conf" -> TrackerConf(0L, hostIp)) - val xgbExecParams = getXGBoostExecutionParams(paramMap) - val tracker = XGBoost.getTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf) - tracker match { - case pyTracker: PyRabitTracker => - val cmd = pyTracker.getRabitTrackerCommand - assert(cmd.contains(hostIp)) - assert(cmd.startsWith("python")) - case _ => assert(false, "expected python tracker implementation") - } - - val paramMap1 = Map( - "num_workers" -> numWorkers, - "tracker_conf" -> TrackerConf(0L, "", pythonExec)) - val xgbExecParams1 = getXGBoostExecutionParams(paramMap1) - val tracker1 = XGBoost.getTracker(xgbExecParams1.numWorkers, xgbExecParams1.trackerConf) - tracker1 match { - case pyTracker: PyRabitTracker => - val cmd = pyTracker.getRabitTrackerCommand - assert(cmd.startsWith(pythonExec)) - assert(!cmd.contains(hostIp)) - case _ => assert(false, "expected python tracker implementation") - } - - val paramMap2 = Map( - "num_workers" -> numWorkers, - "tracker_conf" -> TrackerConf(0L, hostIp, pythonExec)) - val xgbExecParams2 = getXGBoostExecutionParams(paramMap2) - val tracker2 = XGBoost.getTracker(xgbExecParams2.numWorkers, xgbExecParams2.trackerConf) - tracker2 match { - case pyTracker: PyRabitTracker => - val cmd = pyTracker.getRabitTrackerCommand - assert(cmd.startsWith(pythonExec)) - assert(cmd.contains(s" --host-ip=${hostIp}")) - case _ => assert(false, "expected python tracker implementation") - } - } - test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") { /* Deliberately create new instances of SparkContext in each unit test to avoid reusing the @@ -88,9 +43,9 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest { */ val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache() - val tracker = new PyRabitTracker(numWorkers) - tracker.start(0) - val trackerEnvs = tracker.getWorkerEnvs + val tracker = new RabitTracker(numWorkers) + tracker.start() + val trackerEnvs = tracker. workerArgs val workerCount: Int = numWorkers /* @@ -99,22 +54,8 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest { thrown: the thread running the dummy spark job (sparkThread) catches the exception and delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself. - The Java RabitTracker class reacts to exceptions by killing the spawned process running - the Python tracker. If at least one Rabit worker has yet connected to the tracker before - it is killed, the resulted connection failure will trigger the Rabit worker to call - "exit(-1);" in the native C++ code, effectively ending the dummy Spark task. - - In cluster (standalone or YARN) mode of Spark, tasks are run in containers and thus are - isolated from each other. That is, one task calling "exit(-1);" has no effect on other tasks - running in separate containers. However, as unit tests are run in Spark local mode, in which - tasks are executed by threads belonging to the same process, one thread calling "exit(-1);" - ultimately kills the entire process, which also happens to host the Spark driver, causing - the entire Spark application to crash. - To prevent unit tests from crashing, deterministic delays were introduced to make sure that the exception is thrown at last, ideally after all worker connections have been established. - For the same reason, the Java RabitTracker class delays the killing of the Python tracker - process to ensure that pending worker connections are handled. */ val dummyTasks = rdd.mapPartitions { iter => Communicator.init(trackerEnvs) @@ -137,7 +78,32 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest { sparkThread.setUncaughtExceptionHandler(tracker) sparkThread.start() - assert(tracker.waitFor(0) != 0) + } + + test("Communicator allreduce works.") { + val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache() + val tracker = new RabitTracker(numWorkers) + tracker.start() + val trackerEnvs = tracker.workerArgs + + val workerCount: Int = numWorkers + + rdd.mapPartitions { iter => + val index = iter.next() + Communicator.init(trackerEnvs) + val a = Array(1.0f, 2.0f, 3.0f) + System.out.println(a.mkString(", ")) + val b = Communicator.allReduce(a, Communicator.OpType.SUM) + for (i <- 0 to 2) { + assert(a(i) * workerCount == b(i)) + } + val c = Communicator.allReduce(a, Communicator.OpType.MIN); + for (i <- 0 to 2) { + assert(a(i) == c(i)) + } + Communicator.shutdown() + Iterator(index) + }.collect() } test("should allow the dataframe containing communicator calls to be partially evaluated for" + diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala index f187f7394ffa..20a95f2a23e4 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala @@ -23,7 +23,6 @@ import org.apache.spark.SparkException import org.apache.spark.ml.param.ParamMap class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll { - test("XGBoost and Spark parameters synchronize correctly") { val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic", "objective_type" -> "classification") @@ -50,7 +49,6 @@ class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll { intercept[SparkException] { xgb.fit(trainingDF) } - } test("fail training elegantly with unsupported eval metrics") { diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala index 86b82e63ce33..136d39e8bc0f 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala @@ -47,11 +47,6 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest { val model2 = new XGBoostClassifier(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1)) .fit(training) - assert(Communicator.communicatorEnvs.asScala.size > 3) - Communicator.communicatorEnvs.asScala.foreach( item => { - if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1") - }) - val prediction2 = model2.transform(testDF).select("prediction").collect() // check parity w/o rabit cache prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) => @@ -70,10 +65,6 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest { val model2 = new XGBoostRegressor(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1) ).fit(training) - assert(Communicator.communicatorEnvs.asScala.size > 3) - Communicator.communicatorEnvs.asScala.foreach( item => { - if (item._1.toString == "rabit_reduce_ring_mincount") assert(item._2 == "1") - }) // check the equality of single instance prediction val prediction2 = model2.transform(testDF).select("prediction").collect() // check parity w/o rabit cache @@ -81,25 +72,4 @@ class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest { assert(math.abs(p1 - p2) < predictionErrorMin) } } - - test("test rabit timeout fail handle") { - val training = buildDataFrame(Classification.train) - // mock rank 0 failure during 8th allreduce synchronization - Communicator.mockList = Array("0,8,0,0").toList.asJava - - intercept[SparkException] { - new XGBoostClassifier(Map( - "eta" -> "0.1", - "max_depth" -> "10", - "verbosity" -> "1", - "objective" -> "binary:logistic", - "num_round" -> 5, - "num_workers" -> numWorkers, - "rabit_timeout" -> 0)) - .fit(training) - } - - Communicator.mockList = Array.empty.toList.asJava - } - } diff --git a/jvm-packages/xgboost4j-tester/generate_pom.py b/jvm-packages/xgboost4j-tester/generate_pom.py index b9c274c28a4d..ad729b3a64cb 100644 --- a/jvm-packages/xgboost4j-tester/generate_pom.py +++ b/jvm-packages/xgboost4j-tester/generate_pom.py @@ -22,7 +22,7 @@ {scala_version} 3.2.15 {scala_binary_version} - 5.5.0 + 5.6.0 @@ -51,6 +51,11 @@ commons-logging 1.2 + + com.fasterxml.jackson.core + jackson-databind + 2.14.2 + org.scalatest scalatest_${{scala.binary.version}} diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml index 7eb18691995b..5a83a400c50b 100644 --- a/jvm-packages/xgboost4j/pom.xml +++ b/jvm-packages/xgboost4j/pom.xml @@ -60,7 +60,7 @@ org.apache.maven.plugins maven-javadoc-plugin - 3.6.2 + 3.6.3 protected true @@ -76,7 +76,7 @@ exec-maven-plugin org.codehaus.mojo - 3.1.0 + 3.2.0 native @@ -99,7 +99,7 @@ org.apache.maven.plugins maven-jar-plugin - 3.3.0 + 3.4.1 diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Communicator.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Communicator.java index 795e7d99e8fe..ee1bc7b4a5a9 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Communicator.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Communicator.java @@ -7,6 +7,9 @@ import java.util.List; import java.util.Map; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + /** * Collective communicator global class for synchronization. * @@ -30,8 +33,9 @@ public int getOperand() { } public enum DataType implements Serializable { - INT8(0, 1), UINT8(1, 1), INT32(2, 4), UINT32(3, 4), - INT64(4, 8), UINT64(5, 8), FLOAT32(6, 4), FLOAT64(7, 8); + FLOAT16(0, 2), FLOAT32(1, 4), FLOAT64(2, 8), + INT8(4, 1), INT16(5, 2), INT32(6, 4), INT64(7, 8), + UINT8(8, 1), UINT16(9, 2), UINT32(10, 4), UINT64(11, 8); private final int enumOp; private final int size; @@ -56,30 +60,20 @@ private static void checkCall(int ret) throws XGBoostError { } } - // used as way to test/debug passed communicator init parameters - public static Map communicatorEnvs; - public static List mockList = new LinkedList<>(); - /** * Initialize the collective communicator on current working thread. * * @param envs The additional environment variables to pass to the communicator. * @throws XGBoostError */ - public static void init(Map envs) throws XGBoostError { - communicatorEnvs = envs; - String[] args = new String[envs.size() * 2 + mockList.size() * 2]; - int idx = 0; - for (java.util.Map.Entry e : envs.entrySet()) { - args[idx++] = e.getKey(); - args[idx++] = e.getValue(); - } - // pass list of rabit mock strings eg mock=0,1,0,0 - for (String mock : mockList) { - args[idx++] = "mock"; - args[idx++] = mock; + public static void init(Map envs) throws XGBoostError { + ObjectMapper mapper = new ObjectMapper(); + try { + String jconfig = mapper.writeValueAsString(envs); + checkCall(XGBoostJNI.CommunicatorInit(jconfig)); + } catch (JsonProcessingException ex) { + throw new XGBoostError("Failed to read arguments for the communicator.", ex); } - checkCall(XGBoostJNI.CommunicatorInit(args)); } /** diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/IRabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ITracker.java similarity index 56% rename from jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/IRabitTracker.java rename to jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ITracker.java index 984fb80e6dd8..1bfef677d45c 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/IRabitTracker.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ITracker.java @@ -1,14 +1,13 @@ package ml.dmlc.xgboost4j.java; import java.util.Map; -import java.util.concurrent.TimeUnit; /** - * Interface for Rabit tracker implementations with three public methods: + * Interface for a tracker implementations with three public methods: * - * - start(timeout): Start the Rabit tracker awaiting for worker connections, with a given - * timeout value (in milliseconds.) - * - getWorkerEnvs(): Return the environment variables needed to initialize Rabit clients. + * - start(timeout): Start the tracker awaiting for worker connections, with a given + * timeout value (in seconds). + * - workerArgs(): Return the arguments needed to initialize Rabit clients. * - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout` * milliseconds. * @@ -21,7 +20,7 @@ * The Rabit tracker handles connections from distributed workers, assigns ranks to workers, and * brokers connections between workers. */ -public interface IRabitTracker extends Thread.UncaughtExceptionHandler { +public interface ITracker extends Thread.UncaughtExceptionHandler { enum TrackerStatus { SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3); @@ -36,9 +35,11 @@ public int getStatusCode() { } } - Map getWorkerEnvs(); - boolean start(long workerConnectionTimeout); - void stop(); - // taskExecutionTimeout has no effect in current version of XGBoost. - int waitFor(long taskExecutionTimeout); + Map workerArgs() throws XGBoostError; + + boolean start() throws XGBoostError; + + void stop() throws XGBoostError; + + void waitFor(long taskExecutionTimeout) throws XGBoostError; } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java index 0a05b3de0d7f..914a493cc8d1 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java @@ -1,101 +1,40 @@ package ml.dmlc.xgboost4j.java; -import java.io.*; -import java.util.HashMap; import java.util.Map; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; /** * Java implementation of the Rabit tracker to coordinate distributed workers. - * As a wrapper of the Python Rabit tracker, this implementation does not handle timeout for both - * start() and waitFor() methods (i.e., the timeout is infinite.) - * - * For systems lacking Python environment, or for timeout functionality, consider using the Scala - * Rabit tracker (ml.dmlc.xgboost4j.scala.rabit.RabitTracker) which does not depend on Python, and - * provides timeout support. * * The tracker must be started on driver node before running distributed jobs. */ -public class RabitTracker implements IRabitTracker { +public class RabitTracker implements ITracker { // Maybe per tracker logger? private static final Log logger = LogFactory.getLog(RabitTracker.class); - // tracker python file. - private static String tracker_py = null; - private static TrackerProperties trackerProperties = TrackerProperties.getInstance(); - // environment variable to be pased. - private Map envs = new HashMap(); - // number of workers to be submitted. - private int numWorkers; - private String hostIp = ""; - private String pythonExec = ""; - private AtomicReference trackerProcess = new AtomicReference(); - - static { - try { - initTrackerPy(); - } catch (IOException ex) { - logger.error("load tracker library failed."); - logger.error(ex); - } - } - - /** - * Tracker logger that logs output from tracker. - */ - private class TrackerProcessLogger implements Runnable { - public void run() { - - Log trackerProcessLogger = LogFactory.getLog(TrackerProcessLogger.class); - BufferedReader reader = new BufferedReader(new InputStreamReader( - trackerProcess.get().getErrorStream())); - String line; - try { - while ((line = reader.readLine()) != null) { - trackerProcessLogger.info(line); - } - trackerProcess.get().waitFor(); - int exitValue = trackerProcess.get().exitValue(); - if (exitValue != 0) { - trackerProcessLogger.error("Tracker Process ends with exit code " + exitValue); - } else { - trackerProcessLogger.info("Tracker Process ends with exit code " + exitValue); - } - } catch (IOException ex) { - trackerProcessLogger.error(ex.toString()); - } catch (InterruptedException ie) { - // we should not get here as RabitTracker is accessed in the main thread - ie.printStackTrace(); - logger.error("the RabitTracker thread is terminated unexpectedly"); - } - } - } + private long handle = 0; + private Thread tracker_daemon; - private static void initTrackerPy() throws IOException { - try { - tracker_py = NativeLibLoader.createTempFileFromResource("/tracker.py"); - } catch (IOException ioe) { - logger.trace("cannot access tracker python script"); - throw ioe; - } + public RabitTracker(int numWorkers) throws XGBoostError { + this(numWorkers, ""); } - public RabitTracker(int numWorkers) + public RabitTracker(int numWorkers, String hostIp) throws XGBoostError { + this(numWorkers, hostIp, 0, 300); + } + public RabitTracker(int numWorkers, String hostIp, int port, int timeout) throws XGBoostError { if (numWorkers < 1) { throw new XGBoostError("numWorkers must be greater equal to one"); } - this.numWorkers = numWorkers; - } - public RabitTracker(int numWorkers, String hostIp, String pythonExec) - throws XGBoostError { - this(numWorkers); - this.hostIp = hostIp; - this.pythonExec = pythonExec; + long[] out = new long[1]; + XGBoostJNI.checkCall(XGBoostJNI.TrackerCreate(hostIp, numWorkers, port, 0, timeout, out)); + this.handle = out[0]; } public void uncaughtException(Thread t, Throwable e) { @@ -105,7 +44,7 @@ public void uncaughtException(Thread t, Throwable e) { } catch (InterruptedException ex) { logger.error(ex); } finally { - trackerProcess.get().destroy(); + this.tracker_daemon.interrupt(); } } @@ -113,115 +52,43 @@ public void uncaughtException(Thread t, Throwable e) { * Get environments that can be used to pass to worker. * @return The environment settings. */ - public Map getWorkerEnvs() { - return envs; - } - - private void loadEnvs(InputStream ins) throws IOException { - try { - BufferedReader reader = new BufferedReader(new InputStreamReader(ins)); - assert reader.readLine().trim().equals("DMLC_TRACKER_ENV_START"); - String line; - while ((line = reader.readLine()) != null) { - if (line.trim().equals("DMLC_TRACKER_ENV_END")) { - break; - } - String[] sep = line.split("="); - if (sep.length == 2) { - envs.put(sep[0], sep[1]); - } - } - reader.close(); - } catch (IOException ioe){ - logger.error("cannot get runtime configuration from tracker process"); - ioe.printStackTrace(); - throw ioe; - } - } - - /** visible for testing */ - public String getRabitTrackerCommand() { - StringBuilder sb = new StringBuilder(); - if (pythonExec == null || pythonExec.isEmpty()) { - sb.append("python "); - } else { - sb.append(pythonExec + " "); - } - sb.append(" " + tracker_py + " "); - sb.append(" --log-level=DEBUG" + " "); - sb.append(" --num-workers=" + numWorkers + " "); - - // we first check the property then check the parameter - String hostIpFromProperties = trackerProperties.getHostIp(); - if(hostIpFromProperties != null && !hostIpFromProperties.isEmpty()) { - logger.debug("Using provided host-ip: " + hostIpFromProperties + " from properties"); - sb.append(" --host-ip=" + hostIpFromProperties + " "); - } else if (hostIp != null & !hostIp.isEmpty()) { - logger.debug("Using the parametr host-ip: " + hostIp); - sb.append(" --host-ip=" + hostIp + " "); - } - return sb.toString(); - } - - private boolean startTrackerProcess() { + public Map workerArgs() throws XGBoostError { + // fixme: timeout + String[] args = new String[1]; + XGBoostJNI.checkCall(XGBoostJNI.TrackerWorkerArgs(this.handle, 0, args)); + ObjectMapper mapper = new ObjectMapper(); + TypeReference> typeRef = new TypeReference>() { + }; + Map config; try { - String cmd = getRabitTrackerCommand(); - trackerProcess.set(Runtime.getRuntime().exec(cmd)); - loadEnvs(trackerProcess.get().getInputStream()); - return true; - } catch (IOException ioe) { - ioe.printStackTrace(); - return false; + config = mapper.readValue(args[0], typeRef); + } catch (JsonProcessingException ex) { + throw new XGBoostError("Failed to get worker arguments.", ex); } + return config; } - public void stop() { - if (trackerProcess.get() != null) { - trackerProcess.get().destroy(); - } + public void stop() throws XGBoostError { + XGBoostJNI.checkCall(XGBoostJNI.TrackerFree(this.handle)); } - public boolean start(long timeout) { - if (timeout > 0L) { - logger.warn("Python RabitTracker does not support timeout. " + - "The tracker will wait for all workers to connect indefinitely, unless " + - "it is interrupted manually. Use the Scala RabitTracker for timeout support."); - } + public boolean start() throws XGBoostError { + XGBoostJNI.checkCall(XGBoostJNI.TrackerRun(this.handle)); + this.tracker_daemon = new Thread(() -> { + try { + XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, 0)); + } catch (XGBoostError ex) { + logger.error(ex); + return; // exit the thread + } + }); + this.tracker_daemon.setDaemon(true); + this.tracker_daemon.start(); - if (startTrackerProcess()) { - logger.debug("Tracker started, with env=" + envs.toString()); - System.out.println("Tracker started, with env=" + envs.toString()); - // also start a tracker logger - Thread logger_thread = new Thread(new TrackerProcessLogger()); - logger_thread.setDaemon(true); - logger_thread.start(); - return true; - } else { - logger.error("FAULT: failed to start tracker process"); - stop(); - return false; - } + return this.tracker_daemon.isAlive(); } - public int waitFor(long timeout) { - if (timeout > 0L) { - logger.warn("Python RabitTracker does not support timeout. " + - "The tracker will wait for either all workers to finish tasks and send " + - "shutdown signal, or manual interruptions. " + - "Use the Scala RabitTracker for timeout support."); - } - - try { - trackerProcess.get().waitFor(); - int returnVal = trackerProcess.get().exitValue(); - logger.info("Tracker Process ends with exit code " + returnVal); - stop(); - return returnVal; - } catch (InterruptedException e) { - // we should not get here as RabitTracker is accessed in the main thread - e.printStackTrace(); - logger.error("the RabitTracker thread is terminated unexpectedly"); - return TrackerStatus.INTERRUPTED.getStatusCode(); - } + public void waitFor(long timeout) throws XGBoostError { + XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, timeout)); } } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index 2be62a3437d6..71b4ff3f2873 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2023 by Contributors + Copyright (c) 2014-2024 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java index 236d53e900a9..b410d2be1d02 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java @@ -54,7 +54,7 @@ static void checkCall(int ret) throws XGBoostError { public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out); final static native int XGDMatrixCreateFromDataIter(java.util.Iterator iter, - String cache_info, long[] out); + String cache_info, long[] out); public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data, int shapeParam, @@ -146,12 +146,24 @@ public final static native int XGBoosterDumpModelExWithFeatures( public final static native int XGBoosterGetNumBoostedRound(long handle, int[] rounds); // communicator functions - public final static native int CommunicatorInit(String[] args); + public final static native int CommunicatorInit(String args); public final static native int CommunicatorFinalize(); public final static native int CommunicatorPrint(String msg); public final static native int CommunicatorGetRank(int[] out); public final static native int CommunicatorGetWorldSize(int[] out); + // Tracker functions + public final static native int TrackerCreate(String host, int nWorkers, int port, int sortby, long timeout, + long[] out); + + public final static native int TrackerRun(long handle); + + public final static native int TrackerWaitFor(long handle, long timeout); + + public final static native int TrackerWorkerArgs(long handle, long timeout, String[] out); + + public final static native int TrackerFree(long handle); + // Perform Allreduce operation on data in sendrecvbuf. final static native int CommunicatorAllreduce(ByteBuffer sendrecvbuf, int count, int enum_dtype, int enum_op); @@ -168,5 +180,4 @@ public final static native int XGDMatrixCreateFromArrayInterfaceColumns( public final static native int XGBoosterSetStrFeatureInfo(long handle, String field, String[] features); public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out); - } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/util/UtilUnsafe.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/util/UtilUnsafe.java index 501a9cfe186f..e3857a1d4b9e 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/util/UtilUnsafe.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/util/UtilUnsafe.java @@ -42,5 +42,4 @@ private static Unsafe getUnsafe() { throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e); } } - } diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala index 50d86c893171..561b97ff3d2c 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014-2024 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -196,5 +196,3 @@ private[scala] object ExternalCheckpointParams { } } } - - diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index 332b1a12774b..cfab645ed6bf 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -1,29 +1,31 @@ /** - Copyright (c) 2014-2023 by Contributors - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. + * Copyright 2014-2024, XGBoost Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ #include "./xgboost4j.h" -#include #include #include #include #include +#include // for copy_n #include #include #include -#include +#include // for unique_ptr #include #include #include @@ -61,6 +63,11 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { return JNI_VERSION_1_6; } +namespace { +template +using Deleter = std::function; +} // anonymous namespace + XGB_EXTERN_C int XGBoost4jCallbackDataIterNext( DataIterHandle data_handle, XGBCallbackSetData* set_function, @@ -102,54 +109,70 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext( batch, jenv->GetFieldID(batchClass, "featureValue", "[F")); jint jcols = jenv->GetIntField( batch, jenv->GetFieldID(batchClass, "featureCols", "I")); - XGBoostBatchCSR cbatch; - cbatch.size = jenv->GetArrayLength(joffset) - 1; - cbatch.columns = jcols; - cbatch.offset = reinterpret_cast( - jenv->GetLongArrayElements(joffset, 0)); - if (jlabel != nullptr) { - cbatch.label = jenv->GetFloatArrayElements(jlabel, 0); - CHECK_EQ(jenv->GetArrayLength(jlabel), static_cast(cbatch.size)) - << "batch.label.length must equal batch.numRows()"; - } else { - cbatch.label = nullptr; - } - if (jweight != nullptr) { - cbatch.weight = jenv->GetFloatArrayElements(jweight, 0); - CHECK_EQ(jenv->GetArrayLength(jweight), static_cast(cbatch.size)) - << "batch.weight.length must equal batch.numRows()"; - } else { - cbatch.weight = nullptr; - } - long max_elem = cbatch.offset[cbatch.size]; - cbatch.index = (int*) jenv->GetIntArrayElements(jindex, 0); - cbatch.value = jenv->GetFloatArrayElements(jvalue, 0); - - CHECK_EQ(jenv->GetArrayLength(jindex), max_elem) - << "batch.index.length must equal batch.offset.back()"; - CHECK_EQ(jenv->GetArrayLength(jvalue), max_elem) - << "batch.index.length must equal batch.offset.back()"; - // cbatch is ready - CHECK_EQ((*set_function)(set_function_handle, cbatch), 0) - << XGBGetLastError(); - // release the elements. - jenv->ReleaseLongArrayElements( - joffset, reinterpret_cast(cbatch.offset), 0); - jenv->DeleteLocalRef(joffset); - if (jlabel != nullptr) { - jenv->ReleaseFloatArrayElements(jlabel, cbatch.label, 0); - jenv->DeleteLocalRef(jlabel); - } - if (jweight != nullptr) { - jenv->ReleaseFloatArrayElements(jweight, cbatch.weight, 0); - jenv->DeleteLocalRef(jweight); - } - jenv->ReleaseIntArrayElements(jindex, (jint*) cbatch.index, 0); - jenv->DeleteLocalRef(jindex); - jenv->ReleaseFloatArrayElements(jvalue, cbatch.value, 0); - jenv->DeleteLocalRef(jvalue); + + std::unique_ptr> cbatch{ + [&] { + auto ptr = new XGBoostBatchCSR; + auto &cbatch = *ptr; + + // Init + cbatch.size = jenv->GetArrayLength(joffset) - 1; + cbatch.columns = jcols; + cbatch.offset = reinterpret_cast(jenv->GetLongArrayElements(joffset, nullptr)); + + if (jlabel != nullptr) { + cbatch.label = jenv->GetFloatArrayElements(jlabel, nullptr); + CHECK_EQ(jenv->GetArrayLength(jlabel), static_cast(cbatch.size)) + << "batch.label.length must equal batch.numRows()"; + } else { + cbatch.label = nullptr; + } + + if (jweight != nullptr) { + cbatch.weight = jenv->GetFloatArrayElements(jweight, nullptr); + CHECK_EQ(jenv->GetArrayLength(jweight), static_cast(cbatch.size)) + << "batch.weight.length must equal batch.numRows()"; + } else { + cbatch.weight = nullptr; + } + + auto max_elem = cbatch.offset[cbatch.size]; + cbatch.index = (int *)jenv->GetIntArrayElements(jindex, nullptr); + cbatch.value = jenv->GetFloatArrayElements(jvalue, nullptr); + CHECK_EQ(jenv->GetArrayLength(jindex), max_elem) + << "batch.index.length must equal batch.offset.back()"; + CHECK_EQ(jenv->GetArrayLength(jvalue), max_elem) + << "batch.index.length must equal batch.offset.back()"; + return ptr; + }(), + [&](XGBoostBatchCSR *ptr) { + auto &cbatch = *ptr; + jenv->ReleaseLongArrayElements(joffset, reinterpret_cast(cbatch.offset), 0); + jenv->DeleteLocalRef(joffset); + + if (jlabel) { + jenv->ReleaseFloatArrayElements(jlabel, cbatch.label, 0); + jenv->DeleteLocalRef(jlabel); + } + if (jweight) { + jenv->ReleaseFloatArrayElements(jweight, cbatch.weight, 0); + jenv->DeleteLocalRef(jweight); + } + + jenv->ReleaseIntArrayElements(jindex, (jint *)cbatch.index, 0); + jenv->DeleteLocalRef(jindex); + + jenv->ReleaseFloatArrayElements(jvalue, cbatch.value, 0); + jenv->DeleteLocalRef(jvalue); + + delete ptr; + }}; + + CHECK_EQ((*set_function)(set_function_handle, *cbatch), 0) << XGBGetLastError(); + jenv->DeleteLocalRef(batch); jenv->DeleteLocalRef(batchClass); + ret_value = 1; } else { ret_value = 0; @@ -179,7 +202,7 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBGetLastError (JNIEnv *jenv, jclass jcls) { jstring jresult = 0; const char* result = XGBGetLastError(); - if (result != NULL) { + if (result) { jresult = jenv->NewStringUTF(result); } return jresult; @@ -193,16 +216,15 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBGetLastError JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromDataIter (JNIEnv *jenv, jclass jcls, jobject jiter, jstring jcache_info, jlongArray jout) { DMatrixHandle result; - const char* cache_info = nullptr; + std::unique_ptr> cache_info; if (jcache_info != nullptr) { - cache_info = jenv->GetStringUTFChars(jcache_info, 0); + cache_info = {jenv->GetStringUTFChars(jcache_info, nullptr), [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jcache_info, ptr); + }}; } - int ret = XGDMatrixCreateFromDataIter( - jiter, XGBoost4jCallbackDataIterNext, cache_info, &result); + int ret = + XGDMatrixCreateFromDataIter(jiter, XGBoost4jCallbackDataIterNext, cache_info.get(), &result); JVM_CHECK_CALL(ret); - if (cache_info) { - jenv->ReleaseStringUTFChars(jcache_info, cache_info); - } setHandle(jenv, jout, result); return ret; } @@ -212,20 +234,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro * Method: XGDMatrixCreateFromFile * Signature: (Ljava/lang/String;I[J)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromFile - (JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) { +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromFile( + JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) { + std::unique_ptr> fname{jenv->GetStringUTFChars(jfname, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jfname, ptr); + }}; DMatrixHandle result; - const char* fname = jenv->GetStringUTFChars(jfname, 0); - int ret = XGDMatrixCreateFromFile(fname, jsilent, &result); + int ret = XGDMatrixCreateFromFile(fname.get(), jsilent, &result); JVM_CHECK_CALL(ret); - if (fname) { - jenv->ReleaseStringUTFChars(jfname, fname); - } setHandle(jenv, jout, result); return ret; } namespace { +using JavaIndT = + std::conditional_t::value, std::int32_t, long>; /** * \brief Create from sparse matrix. * @@ -238,20 +262,28 @@ jint MakeJVMSparseInput(JNIEnv *jenv, jlongArray jindptr, jintArray jindices, jf jfloat jmissing, jint jnthread, Fn &&maker, jlongArray jout) { DMatrixHandle result; - jlong *indptr = jenv->GetLongArrayElements(jindptr, nullptr); - jint *indices = jenv->GetIntArrayElements(jindices, nullptr); - jfloat *data = jenv->GetFloatArrayElements(jdata, nullptr); + std::unique_ptr> indptr{jenv->GetLongArrayElements(jindptr, nullptr), + [&](jlong *ptr) { + jenv->ReleaseLongArrayElements(jindptr, ptr, 0); + }}; + std::unique_ptr> indices{jenv->GetIntArrayElements(jindices, nullptr), + [&](jint *ptr) { + jenv->ReleaseIntArrayElements(jindices, ptr, 0); + }}; + std::unique_ptr> data{jenv->GetFloatArrayElements(jdata, nullptr), + [&](jfloat *ptr) { + jenv->ReleaseFloatArrayElements(jdata, ptr, 0); + }}; + bst_ulong nindptr = static_cast(jenv->GetArrayLength(jindptr)); bst_ulong nelem = static_cast(jenv->GetArrayLength(jdata)); std::string sindptr, sindices, sdata; - CHECK_EQ(indptr[nindptr - 1], nelem); + CHECK_EQ(indptr.get()[nindptr - 1], nelem); using IndPtrT = std::conditional_t::value, long, long long>; - using IndT = - std::conditional_t::value, std::int32_t, long>; xgboost::detail::MakeSparseFromPtr( - static_cast(indptr), static_cast(indices), - static_cast(data), nindptr, &sindptr, &sindices, &sdata); + static_cast(indptr.get()), static_cast(indices.get()), + static_cast(data.get()), nindptr, &sindptr, &sindices, &sdata); xgboost::Json jconfig{xgboost::Object{}}; auto missing = static_cast(jmissing); @@ -265,11 +297,6 @@ jint MakeJVMSparseInput(JNIEnv *jenv, jlongArray jindptr, jintArray jindices, jf jint ret = maker(sindptr.c_str(), sindices.c_str(), sdata.c_str(), config.c_str(), &result); JVM_CHECK_CALL(ret); setHandle(jenv, jout, result); - - // Release - jenv->ReleaseLongArrayElements(jindptr, indptr, 0); - jenv->ReleaseIntArrayElements(jindices, indices, 0); - jenv->ReleaseFloatArrayElements(jdata, data, 0); return ret; } } // anonymous namespace @@ -335,37 +362,55 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromMat (JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout) { DMatrixHandle result; - jfloat* data = jenv->GetFloatArrayElements(jdata, 0); + std::unique_ptr> data{jenv->GetFloatArrayElements(jdata, 0), [&](jfloat* ptr) { + jenv->ReleaseFloatArrayElements(jdata, ptr, 0); + }}; + bst_ulong nrow = (bst_ulong)jnrow; bst_ulong ncol = (bst_ulong)jncol; - jint ret = (jint) XGDMatrixCreateFromMat((float const *)data, nrow, ncol, jmiss, &result); + jint ret = + XGDMatrixCreateFromMat(static_cast(data.get()), nrow, ncol, jmiss, &result); JVM_CHECK_CALL(ret); setHandle(jenv, jout, result); - //release - jenv->ReleaseFloatArrayElements(jdata, data, 0); return ret; } +namespace { +// Workaround int is not the same as jint. For some reason, if constexpr couldn't dispatch +// the following. +template +auto SliaceDMatrixWinWar(DMatrixHandle handle, T *ptr, std::size_t len, DMatrixHandle *result) { + // default to not allowing slicing with group ID specified -- feel free to add if necessary + return XGDMatrixSliceDMatrixEx(handle, ptr, len, result, 0); +} + +template <> +auto SliaceDMatrixWinWar(DMatrixHandle handle, long *ptr, std::size_t len, DMatrixHandle *result) { + std::vector copy(len); + std::copy_n(ptr, len, copy.begin()); + // default to not allowing slicing with group ID specified -- feel free to add if necessary + return XGDMatrixSliceDMatrixEx(handle, copy.data(), len, result, 0); +} +} // namespace + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGDMatrixSliceDMatrix * Signature: (J[I)J */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSliceDMatrix - (JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset, jlongArray jout) { +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSliceDMatrix( + JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset, jlongArray jout) { DMatrixHandle result; - DMatrixHandle handle = (DMatrixHandle) jhandle; - - jint* indexset = jenv->GetIntArrayElements(jindexset, 0); - bst_ulong len = (bst_ulong)jenv->GetArrayLength(jindexset); - - // default to not allowing slicing with group ID specified -- feel free to add if necessary - jint ret = (jint) XGDMatrixSliceDMatrixEx(handle, (int const *)indexset, len, &result, 0); + auto handle = reinterpret_cast(jhandle); + + std::unique_ptr> indexset{jenv->GetIntArrayElements(jindexset, nullptr), + [&](jint *ptr) { + jenv->ReleaseIntArrayElements(jindexset, ptr, 0); + }}; + auto len = static_cast(jenv->GetArrayLength(jindexset)); + auto ret = SliaceDMatrixWinWar(handle, indexset.get(), len, &result); JVM_CHECK_CALL(ret); setHandle(jenv, jout, result); - //release - jenv->ReleaseIntArrayElements(jindexset, indexset, 0); - return ret; } @@ -386,13 +431,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixFree * Method: XGDMatrixSaveBinary * Signature: (JLjava/lang/String;I)V */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSaveBinary - (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - const char* fname = jenv->GetStringUTFChars(jfname, 0); - int ret = XGDMatrixSaveBinary(handle, fname, jsilent); +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSaveBinary( + JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) { + DMatrixHandle handle = reinterpret_cast(jhandle); + std::unique_ptr> fname{ + jenv->GetStringUTFChars(jfname, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jfname, ptr); + } + }}; + int ret = XGDMatrixSaveBinary(handle, fname.get(), jsilent); JVM_CHECK_CALL(ret); - if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname); return ret; } @@ -401,19 +450,23 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSaveBinar * Method: XGDMatrixSetFloatInfo * Signature: (JLjava/lang/String;[F)V */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatInfo - (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - const char* field = jenv->GetStringUTFChars(jfield, 0); +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatInfo( + JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) { + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> field{ + jenv->GetStringUTFChars(jfield, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jfield, ptr); + } + }}; + std::unique_ptr> array{jenv->GetFloatArrayElements(jarray, nullptr), + [&](jfloat *ptr) { + jenv->ReleaseFloatArrayElements(jarray, ptr, 0); + }}; - jfloat* array = jenv->GetFloatArrayElements(jarray, NULL); bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray); - int ret = XGDMatrixSetFloatInfo(handle, field, (float const *)array, len); - JVM_CHECK_CALL(ret); - //release - if (field) jenv->ReleaseStringUTFChars(jfield, field); - jenv->ReleaseFloatArrayElements(jarray, array, 0); - return ret; + auto str = xgboost::linalg::Make1dInterface(array.get(), len); + return XGDMatrixSetInfoFromInterface(handle, field.get(), str.c_str()); } /* @@ -423,17 +476,20 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatI */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntInfo (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jintArray jarray) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - const char* field = jenv->GetStringUTFChars(jfield, 0); - jint* array = jenv->GetIntArrayElements(jarray, NULL); + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> field{ + jenv->GetStringUTFChars(jfield, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jfield, ptr); + } + }}; + std::unique_ptr> array{jenv->GetIntArrayElements(jarray, nullptr), + [&](jint *ptr) { + jenv->ReleaseIntArrayElements(jarray, ptr, 0); + }}; bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray); - int ret = XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len); - JVM_CHECK_CALL(ret); - //release - if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field); - jenv->ReleaseIntArrayElements(jarray, array, 0); - - return ret; + auto str = xgboost::linalg::Make1dInterface(array.get(), len); + return XGDMatrixSetInfoFromInterface(handle, field.get(), str.c_str()); } /* @@ -443,13 +499,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntIn */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetFloatInfo (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - const char* field = jenv->GetStringUTFChars(jfield, 0); + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> field{ + jenv->GetStringUTFChars(jfield, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jfield, ptr); + } + }}; bst_ulong len; float *result; - int ret = XGDMatrixGetFloatInfo(handle, field, &len, (const float**) &result); + int ret = XGDMatrixGetFloatInfo(handle, field.get(), &len, (const float**) &result); JVM_CHECK_CALL(ret); - if (field) jenv->ReleaseStringUTFChars(jfield, field); jsize jlen = (jsize) len; jfloatArray jarray = jenv->NewFloatArray(jlen); @@ -466,13 +526,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetFloatI */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetUIntInfo (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - const char* field = jenv->GetStringUTFChars(jfield, 0); + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> field{ + jenv->GetStringUTFChars(jfield, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jfield, ptr); + } + }}; bst_ulong len; unsigned int *result; - int ret = (jint) XGDMatrixGetUIntInfo(handle, field, &len, (const unsigned int **) &result); + int ret = (jint)XGDMatrixGetUIntInfo(handle, field.get(), &len, (const unsigned int **)&result); JVM_CHECK_CALL(ret); - if (field) jenv->ReleaseStringUTFChars(jfield, field); jsize jlen = (jsize) len; jintArray jarray = jenv->NewIntArray(jlen); @@ -488,7 +552,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetUIntIn */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixNumRow (JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) { - DMatrixHandle handle = (DMatrixHandle) jhandle; + auto handle = reinterpret_cast(jhandle); bst_ulong result[1]; int ret = (jint) XGDMatrixNumRow(handle, result); JVM_CHECK_CALL(ret); @@ -523,11 +587,13 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterCreate std::vector handles; if (jhandles != nullptr) { size_t len = jenv->GetArrayLength(jhandles); - jlong *cjhandles = jenv->GetLongArrayElements(jhandles, 0); + std::unique_ptr> cjhandles{ + jenv->GetLongArrayElements(jhandles, nullptr), [&](jlong *ptr) { + jenv->ReleaseLongArrayElements(jhandles, ptr, 0); + }}; for (size_t i = 0; i < len; ++i) { - handles.push_back((DMatrixHandle) cjhandles[i]); + handles.push_back(reinterpret_cast(cjhandles.get()[i])); } - jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0); } BoosterHandle result; int ret = XGBoosterCreate(dmlc::BeginPtr(handles), handles.size(), &result); @@ -541,28 +607,35 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterCreate * Method: XGBoosterFree * Signature: (J)V */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterFree - (JNIEnv *jenv, jclass jcls, jlong jhandle) { - BoosterHandle handle = (BoosterHandle) jhandle; - return XGBoosterFree(handle); +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterFree(JNIEnv *jenv, + jclass jcls, + jlong jhandle) { + auto handle = reinterpret_cast(jhandle); + return XGBoosterFree(handle); } - /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGBoosterSetParam * Signature: (JLjava/lang/String;Ljava/lang/String;)V */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetParam - (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) { - BoosterHandle handle = (BoosterHandle) jhandle; - const char* name = jenv->GetStringUTFChars(jname, 0); - const char* value = jenv->GetStringUTFChars(jvalue, 0); - int ret = XGBoosterSetParam(handle, name, value); +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetParam( + JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) { + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> name{jenv->GetStringUTFChars(jname, nullptr), + [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jname, ptr); + } + }}; + std::unique_ptr> value{ + jenv->GetStringUTFChars(jvalue, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jvalue, ptr); + } + }}; + int ret = XGBoosterSetParam(handle, name.get(), value.get()); JVM_CHECK_CALL(ret); - //release - if (name) jenv->ReleaseStringUTFChars(jname, name); - if (value) jenv->ReleaseStringUTFChars(jvalue, value); return ret; } @@ -573,8 +646,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetParam */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterUpdateOneIter (JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlong jdtrain) { - BoosterHandle handle = (BoosterHandle) jhandle; - DMatrixHandle dtrain = (DMatrixHandle) jdtrain; + auto handle = reinterpret_cast(jhandle); + auto dtrain = reinterpret_cast(jdtrain); return XGBoosterUpdateOneIter(handle, jiter, dtrain); } @@ -587,16 +660,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterTrainOneI JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jint jiter, jfloatArray jgrad, jfloatArray jhess) { API_BEGIN(); - BoosterHandle handle = reinterpret_cast(jhandle); - DMatrixHandle dtrain = reinterpret_cast(jdtrain); + auto handle = reinterpret_cast(jhandle); + auto dtrain = reinterpret_cast(jdtrain); CHECK(handle); CHECK(dtrain); bst_ulong n_samples{0}; JVM_CHECK_CALL(XGDMatrixNumRow(dtrain, &n_samples)); bst_ulong len = static_cast(jenv->GetArrayLength(jgrad)); - jfloat *grad = jenv->GetFloatArrayElements(jgrad, nullptr); - jfloat *hess = jenv->GetFloatArrayElements(jhess, nullptr); + std::unique_ptr> grad{jenv->GetFloatArrayElements(jgrad, nullptr), + [&](jfloat *ptr) { + jenv->ReleaseFloatArrayElements(jgrad, ptr, 0); + }}; + std::unique_ptr> hess{jenv->GetFloatArrayElements(jhess, nullptr), + [&](jfloat *ptr) { + jenv->ReleaseFloatArrayElements(jhess, ptr, 0); + }}; CHECK(grad); CHECK(hess); @@ -608,15 +687,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterTrainOneI auto ctx = xgboost::detail::BoosterCtx(handle); auto [s_grad, s_hess] = xgboost::detail::MakeGradientInterface( - ctx, grad, hess, xgboost::linalg::kC, n_samples, n_targets); - int ret = XGBoosterTrainOneIter(handle, dtrain, static_cast(jiter), s_grad.c_str(), - s_hess.c_str()); - - // release - jenv->ReleaseFloatArrayElements(jgrad, grad, 0); - jenv->ReleaseFloatArrayElements(jhess, hess, 0); - - return ret; + ctx, grad.get(), hess.get(), xgboost::linalg::kC, n_samples, n_targets); + return XGBoosterTrainOneIter(handle, dtrain, static_cast(jiter), s_grad.c_str(), + s_hess.c_str()); API_END(); } @@ -627,30 +700,33 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterTrainOneI */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterEvalOneIter (JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) { - BoosterHandle handle = (BoosterHandle) jhandle; + auto handle = reinterpret_cast(jhandle); std::vector dmats; std::vector evnames; std::vector evchars; size_t len = static_cast(jenv->GetArrayLength(jdmats)); // put handle from jhandles to chandles - jlong* cjdmats = jenv->GetLongArrayElements(jdmats, 0); + std::unique_ptr> cjdmats{ + jenv->GetLongArrayElements(jdmats, nullptr), [&](jlong *ptr) { + jenv->ReleaseLongArrayElements(jdmats, ptr, 0); + }}; for (size_t i = 0; i < len; ++i) { - dmats.push_back((DMatrixHandle) cjdmats[i]); + dmats.push_back(reinterpret_cast(cjdmats.get()[i])); jstring jevname = (jstring)jenv->GetObjectArrayElement(jevnames, i); - const char *s =jenv->GetStringUTFChars(jevname, 0); - evnames.push_back(std::string(s, jenv->GetStringLength(jevname))); - if (s != nullptr) jenv->ReleaseStringUTFChars(jevname, s); + std::unique_ptr> s{jenv->GetStringUTFChars(jevname, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jevname, ptr); + }}; + evnames.emplace_back(s.get(), jenv->GetStringLength(jevname)); } - jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0); + for (size_t i = 0; i < len; ++i) { evchars.push_back(evnames[i].c_str()); } - const char* result; - int ret = XGBoosterEvalOneIter(handle, jiter, - dmlc::BeginPtr(dmats), - dmlc::BeginPtr(evchars), - len, &result); + const char *result; + int ret = XGBoosterEvalOneIter(handle, jiter, dmlc::BeginPtr(dmats), dmlc::BeginPtr(evchars), len, + &result); JVM_CHECK_CALL(ret); jstring jinfo = nullptr; if (result != nullptr) { @@ -667,8 +743,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterEvalOneIt */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jint jntree_limit, jobjectArray jout) { - BoosterHandle handle = (BoosterHandle) jhandle; - DMatrixHandle dmat = (DMatrixHandle) jdmat; + auto handle = reinterpret_cast(jhandle); + auto dmat = reinterpret_cast(jdmat); bst_ulong len; float *result; int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, @@ -694,7 +770,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFr jfloat missing, jint iteration_begin, jint iteration_end, jint predict_type, jfloatArray jmargin, jobjectArray jout) { API_BEGIN(); - BoosterHandle handle = reinterpret_cast(jhandle); + auto handle = reinterpret_cast(jhandle); /** * Create array interface. @@ -730,8 +806,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFr if (jmargin) { margin = jenv->GetFloatArrayElements(jmargin, nullptr); JVM_CHECK_CALL(XGProxyDMatrixCreate(&proxy)); - JVM_CHECK_CALL( - XGDMatrixSetFloatInfo(proxy, "base_margin", margin, jenv->GetArrayLength(jmargin))); + auto str = xgboost::linalg::Make1dInterface(margin, jenv->GetArrayLength(jmargin)); + JVM_CHECK_CALL(XGDMatrixSetInfoFromInterface(proxy, "base_margin", str.c_str())); } bst_ulong const *out_shape; @@ -768,17 +844,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFr * Method: XGBoosterLoadModel * Signature: (JLjava/lang/String;)V */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel - (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) { - BoosterHandle handle = (BoosterHandle) jhandle; - const char* fname = jenv->GetStringUTFChars(jfname, 0); - - int ret = XGBoosterLoadModel(handle, fname); - JVM_CHECK_CALL(ret); - if (fname) { - jenv->ReleaseStringUTFChars(jfname,fname); - } - return ret; +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel(JNIEnv *jenv, + jclass jcls, + jlong jhandle, + jstring jfname) { + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> fname{jenv->GetStringUTFChars(jfname, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jfname, ptr); + }}; + return XGBoosterLoadModel(handle, fname.get()); } /* @@ -786,17 +861,18 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel * Method: XGBoosterSaveModel * Signature: (JLjava/lang/String;)V */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModel - (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) { - BoosterHandle handle = (BoosterHandle) jhandle; - const char* fname = jenv->GetStringUTFChars(jfname, 0); - - int ret = XGBoosterSaveModel(handle, fname); - JVM_CHECK_CALL(ret); - if (fname) { - jenv->ReleaseStringUTFChars(jfname, fname); - } - return ret; +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModel(JNIEnv *jenv, + jclass jcls, + jlong jhandle, + jstring jfname) { + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> fname{ + jenv->GetStringUTFChars(jfname, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jfname, ptr); + } + }}; + return XGBoosterSaveModel(handle, fname.get()); } /* @@ -804,15 +880,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModel * Method: XGBoosterLoadModelFromBuffer * Signature: (J[B)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModelFromBuffer - (JNIEnv *jenv, jclass jcls, jlong jhandle, jbyteArray jbytes) { - BoosterHandle handle = (BoosterHandle) jhandle; - jbyte* buffer = jenv->GetByteArrayElements(jbytes, 0); - int ret = XGBoosterLoadModelFromBuffer( - handle, buffer, jenv->GetArrayLength(jbytes)); - JVM_CHECK_CALL(ret); - jenv->ReleaseByteArrayElements(jbytes, buffer, 0); - return ret; +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModelFromBuffer( + JNIEnv *jenv, jclass jcls, jlong jhandle, jbyteArray jbytes) { + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> buffer{jenv->GetByteArrayElements(jbytes, nullptr), + [&](jbyte *ptr) { + jenv->ReleaseByteArrayElements(jbytes, ptr, 0); + }}; + return XGBoosterLoadModelFromBuffer(handle, buffer.get(), jenv->GetArrayLength(jbytes)); } /* @@ -822,12 +897,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModelToBuffer (JNIEnv * jenv, jclass jcls, jlong jhandle, jstring jformat, jobjectArray jout) { - BoosterHandle handle = (BoosterHandle) jhandle; - const char *format = jenv->GetStringUTFChars(jformat, 0); + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> format{ + jenv->GetStringUTFChars(jformat, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jformat, ptr); + } + }}; bst_ulong len = 0; const char *result{nullptr}; - xgboost::Json config {xgboost::Object{}}; - config["format"] = std::string{format}; + xgboost::Json config{xgboost::Object{}}; + config["format"] = std::string{format.get()}; std::string config_str; xgboost::Json::Dump(config, &config_str); @@ -848,13 +928,23 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModel */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModelEx (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jstring jformat, jobjectArray jout) { - BoosterHandle handle = (BoosterHandle) jhandle; - const char *fmap = jenv->GetStringUTFChars(jfmap, 0); - const char *format = jenv->GetStringUTFChars(jformat, 0); + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> fmap{jenv->GetStringUTFChars(jfmap, nullptr), + [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jfmap, ptr); + } + }}; + std::unique_ptr> format{ + jenv->GetStringUTFChars(jformat, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jformat, ptr); + } + }}; bst_ulong len = 0; - char **result; + char const **result; - int ret = XGBoosterDumpModelEx(handle, fmap, jwith_stats, format, &len, (const char ***) &result); + int ret = XGBoosterDumpModelEx(handle, fmap.get(), jwith_stats, format.get(), &len, &result); JVM_CHECK_CALL(ret); jsize jlen = (jsize) len; @@ -864,7 +954,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel } jenv->SetObjectArrayElement(jout, 0, jinfos); - if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap); return ret; } @@ -876,37 +965,48 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModelExWithFeatures (JNIEnv *jenv, jclass jcls, jlong jhandle, jobjectArray jfeature_names, jint jwith_stats, jstring jformat, jobjectArray jout) { - - BoosterHandle handle = (BoosterHandle) jhandle; + auto handle = reinterpret_cast(jhandle); bst_ulong feature_num = (bst_ulong)jenv->GetArrayLength(jfeature_names); std::vector feature_names; - std::vector feature_names_char; + std::vector feature_names_char; std::string feature_type_q = "q"; - std::vector feature_types_char; + std::vector feature_types_char; for (bst_ulong i = 0; i < feature_num; ++i) { jstring jfeature_name = (jstring)jenv->GetObjectArrayElement(jfeature_names, i); - const char *s = jenv->GetStringUTFChars(jfeature_name, 0); - feature_names.push_back(std::string(s, jenv->GetStringLength(jfeature_name))); - if (s != nullptr) jenv->ReleaseStringUTFChars(jfeature_name, s); - if (feature_names.back().length() == 0) feature_names.pop_back(); + std::unique_ptr> s{ + jenv->GetStringUTFChars(jfeature_name, nullptr), [&](char const *ptr) { + if (ptr != nullptr) { + jenv->ReleaseStringUTFChars(jfeature_name, ptr); + } + }}; + feature_names.emplace_back(s.get(), jenv->GetStringLength(jfeature_name)); + + if (feature_names.back().length() == 0) { + feature_names.pop_back(); + } } for (size_t i = 0; i < feature_names.size(); ++i) { - feature_names_char.push_back(&feature_names[i][0]); - feature_types_char.push_back(&feature_type_q[0]); + feature_names_char.push_back(feature_names[i].c_str()); + feature_types_char.push_back(feature_type_q.c_str()); } - const char *format = jenv->GetStringUTFChars(jformat, 0); + std::unique_ptr> format{ + jenv->GetStringUTFChars(jformat, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jformat, ptr); + } + }}; bst_ulong len = 0; char **result; - int ret = XGBoosterDumpModelExWithFeatures(handle, feature_num, - (const char **) dmlc::BeginPtr(feature_names_char), - (const char **) dmlc::BeginPtr(feature_types_char), - jwith_stats, format, &len, (const char ***) &result); + int ret = XGBoosterDumpModelExWithFeatures( + handle, feature_num, (const char **)dmlc::BeginPtr(feature_names_char), + (const char **)dmlc::BeginPtr(feature_types_char), jwith_stats, format.get(), &len, + (const char ***)&result); JVM_CHECK_CALL(ret); jsize jlen = (jsize) len; @@ -947,16 +1047,20 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttrNa * Method: XGBoosterGetAttr * Signature: (JLjava/lang/String;[Ljava/lang/String;)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr - (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jkey, jobjectArray jout) { - BoosterHandle handle = (BoosterHandle) jhandle; - const char* key = jenv->GetStringUTFChars(jkey, 0); - const char* result; +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr( + JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jkey, jobjectArray jout) { + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> key{jenv->GetStringUTFChars(jkey, nullptr), + [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jkey, ptr); + } + }}; + + const char *result; int success; - int ret = XGBoosterGetAttr(handle, key, &result, &success); + int ret = XGBoosterGetAttr(handle, key.get(), &result, &success); JVM_CHECK_CALL(ret); - //release - if (key) jenv->ReleaseStringUTFChars(jkey, key); if (success > 0) { jstring jret = jenv->NewStringUTF(result); @@ -971,17 +1075,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr * Method: XGBoosterSetAttr * Signature: (JLjava/lang/String;Ljava/lang/String;)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr - (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jkey, jstring jvalue) { - BoosterHandle handle = (BoosterHandle) jhandle; - const char* key = jenv->GetStringUTFChars(jkey, 0); - const char* value = jenv->GetStringUTFChars(jvalue, 0); - int ret = XGBoosterSetAttr(handle, key, value); - JVM_CHECK_CALL(ret); - //release - if (key) jenv->ReleaseStringUTFChars(jkey, key); - if (value) jenv->ReleaseStringUTFChars(jvalue, value); - return ret; +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr( + JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jkey, jstring jvalue) { + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> key{jenv->GetStringUTFChars(jkey, nullptr), + [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jkey, ptr); + } + }}; + std::unique_ptr> value{ + jenv->GetStringUTFChars(jvalue, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jvalue, ptr); + } + }}; + return XGBoosterSetAttr(handle, key.get(), value.get()); } /* @@ -989,9 +1098,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr * Method: XGBoosterGetNumFeature * Signature: (J[J)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature - (JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) { - BoosterHandle handle = (BoosterHandle) jhandle; +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature( + JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) { + auto handle = reinterpret_cast(jhandle); bst_ulong num_feature; int ret = XGBoosterGetNumFeature(handle, &num_feature); JVM_CHECK_CALL(ret); @@ -1002,7 +1111,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFea JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoostedRound( JNIEnv *jenv, jclass, jlong jhandle, jintArray jout) { - BoosterHandle handle = (BoosterHandle)jhandle; + auto handle = reinterpret_cast(jhandle); std::int32_t n_rounds{0}; auto ret = XGBoosterBoostedRounds(handle, &n_rounds); JVM_CHECK_CALL(ret); @@ -1014,23 +1123,115 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoo /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: CommunicatorInit - * Signature: ([Ljava/lang/String;)I + * Signature: (Ljava/lang/String;)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit - (JNIEnv *jenv, jclass jcls, jobjectArray jargs) { +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit(JNIEnv *jenv, + jclass jcls, + jstring jargs) { xgboost::Json config{xgboost::Object{}}; - bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs); - assert(len % 2 == 0); - for (bst_ulong i = 0; i < len / 2; ++i) { - jstring key = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i); - std::string key_str(jenv->GetStringUTFChars(key, 0), jenv->GetStringLength(key)); - jstring value = (jstring)jenv->GetObjectArrayElement(jargs, 2 * i + 1); - std::string value_str(jenv->GetStringUTFChars(value, 0), jenv->GetStringLength(value)); - config[key_str] = xgboost::String(value_str); + std::unique_ptr> args{jenv->GetStringUTFChars(jargs, nullptr), + [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jargs, ptr); + } + }}; + return XGCommunicatorInit(args.get()); +} + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: TrackerCreate + * Signature: (Ljava/lang/String;IIIJ[J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerCreate( + JNIEnv *jenv, jclass, jstring host, jint n_workers, jint port, jint sortby, jlong timeout, + jlongArray jout) { + using namespace xgboost; // NOLINT + + TrackerHandle handle; + Json config{Object{}}; + std::unique_ptr> p_shost{jenv->GetStringUTFChars(host, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(host, ptr); + }}; + std::string shost{p_shost.get(), + static_cast(jenv->GetStringLength(host))}; + if (!shost.empty()) { + config["host"] = shost; } - std::string json_str; - xgboost::Json::Dump(config, &json_str); - JVM_CHECK_CALL(XGCommunicatorInit(json_str.c_str())); + config["port"] = Integer{static_cast(port)}; + config["n_workers"] = Integer{static_cast(n_workers)}; + config["timeout"] = Integer{static_cast(timeout)}; + config["sortby"] = Integer{static_cast(sortby)}; + config["dmlc_communicator"] = String{"rabit"}; + std::string sconfig = Json::Dump(config); + JVM_CHECK_CALL(XGTrackerCreate(sconfig.c_str(), &handle)); + setHandle(jenv, jout, handle); + + return 0; +} + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: TrackerRun + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerRun(JNIEnv *, jclass, + jlong jhandle) { + auto handle = reinterpret_cast(jhandle); + JVM_CHECK_CALL(XGTrackerRun(handle, nullptr)); + return 0; +} + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: TrackerWaitFor + * Signature: (JJ)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWaitFor(JNIEnv *, jclass, + jlong jhandle, + jlong timeout) { + using namespace xgboost; // NOLINT + + auto handle = reinterpret_cast(jhandle); + Json config{Object{}}; + config["timeout"] = Integer{static_cast(timeout)}; + std::string sconfig = Json::Dump(config); + JVM_CHECK_CALL(XGTrackerWaitFor(handle, sconfig.c_str())); + return 0; +} + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: TrackerWorkerArgs + * Signature: (JJ[Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWorkerArgs( + JNIEnv *jenv, jclass, jlong jhandle, jlong timeout, jobjectArray jout) { + using namespace xgboost; // NOLINT + + Json config{Object{}}; + config["timeout"] = Integer{static_cast(timeout)}; + std::string sconfig = Json::Dump(config); + auto handle = reinterpret_cast(jhandle); + char const *args; + JVM_CHECK_CALL(XGTrackerWorkerArgs(handle, &args)); + auto jargs = Json::Load(StringView{args}); + + jstring jret = jenv->NewStringUTF(args); + jenv->SetObjectArrayElement(jout, 0, jret); + return 0; +} + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: TrackerFree + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerFree(JNIEnv *, jclass, + jlong jhandle) { + auto handle = reinterpret_cast(jhandle); + JVM_CHECK_CALL(XGTrackerFree(handle)); return 0; } @@ -1039,8 +1240,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit * Method: CommunicatorFinalize * Signature: ()I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize - (JNIEnv *jenv, jclass jcls) { +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinalize(JNIEnv *, + jclass) { JVM_CHECK_CALL(XGCommunicatorFinalize()); return 0; } @@ -1050,12 +1251,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinali * Method: CommunicatorPrint * Signature: (Ljava/lang/String;)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorPrint - (JNIEnv *jenv, jclass jcls, jstring jmsg) { - std::string str(jenv->GetStringUTFChars(jmsg, 0), - jenv->GetStringLength(jmsg)); - JVM_CHECK_CALL(XGCommunicatorPrint(str.c_str())); - return 0; +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorPrint(JNIEnv *jenv, + jclass jcls, + jstring jmsg) { + std::unique_ptr> msg{jenv->GetStringUTFChars(jmsg, nullptr), + [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jmsg, ptr); + } + }}; + std::string str(msg.get(), jenv->GetStringLength(jmsg)); + return XGCommunicatorPrint(str.c_str()); } /* @@ -1124,11 +1330,15 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDM * Method: XGQuantileDMatrixCreateFromCallback * Signature: (Ljava/util/Iterator;Ljava/util/Iterator;Ljava/lang/String;[J)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixCreateFromCallback - (JNIEnv *jenv, jclass jcls, jobject jdata_iter, jobject jref_iter, jstring jconf, jlongArray jout) { - char const *conf = jenv->GetStringUTFChars(jconf, 0); +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixCreateFromCallback( + JNIEnv *jenv, jclass jcls, jobject jdata_iter, jobject jref_iter, jstring jconf, + jlongArray jout) { + std::unique_ptr> conf{jenv->GetStringUTFChars(jconf, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jconf, ptr); + }}; return xgboost::jni::XGQuantileDMatrixCreateFromCallbackImpl(jenv, jcls, jdata_iter, jref_iter, - conf, jout); + conf.get(), jout); } /* @@ -1136,18 +1346,19 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixC * Method: XGDMatrixSetInfoFromInterface * Signature: (JLjava/lang/String;Ljava/lang/String;)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface - (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jstring jjson_columns) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - const char* field = jenv->GetStringUTFChars(jfield, 0); - const char* cjson_columns = jenv->GetStringUTFChars(jjson_columns, 0); - - int ret = XGDMatrixSetInfoFromInterface(handle, field, cjson_columns); - JVM_CHECK_CALL(ret); - //release - if (field) jenv->ReleaseStringUTFChars(jfield, field); - if (cjson_columns) jenv->ReleaseStringUTFChars(jjson_columns, cjson_columns); - return ret; +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface( + JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jstring jjson_columns) { + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> field{jenv->GetStringUTFChars(jfield, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jfield, ptr); + }}; + std::unique_ptr> cjson_columns{ + jenv->GetStringUTFChars(jjson_columns, nullptr), [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jjson_columns, ptr); + }}; + + return XGDMatrixSetInfoFromInterface(handle, field.get(), cjson_columns.get()); } /* @@ -1158,7 +1369,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFr JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns (JNIEnv *jenv, jclass jcls, jstring jjson_columns, jfloat jmissing, jint jnthread, jlongArray jout) { DMatrixHandle result; - const char* cjson_columns = jenv->GetStringUTFChars(jjson_columns, nullptr); + std::unique_ptr> cjson_columns{ + jenv->GetStringUTFChars(jjson_columns, nullptr), [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jjson_columns, ptr); + }}; xgboost::Json config{xgboost::Object{}}; auto missing = static_cast(jmissing); auto n_threads = static_cast(jnthread); @@ -1166,43 +1380,38 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro config["nthread"] = xgboost::Integer(n_threads); std::string config_str; xgboost::Json::Dump(config, &config_str); - int ret = XGDMatrixCreateFromCudaColumnar(cjson_columns, config_str.c_str(), - &result); + int ret = XGDMatrixCreateFromCudaColumnar(cjson_columns.get(), config_str.c_str(), &result); JVM_CHECK_CALL(ret); - if (cjson_columns) { - jenv->ReleaseStringUTFChars(jjson_columns, cjson_columns); - } - setHandle(jenv, jout, result); return ret; } JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetStrFeatureInfo (JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield, jobjectArray jvalues) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - const char* field = jenv->GetStringUTFChars(jfield, 0); + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> field{jenv->GetStringUTFChars(jfield, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jfield, ptr); + }}; int size = jenv->GetArrayLength(jvalues); // tmp storage for java strings std::vector values; for (int i = 0; i < size; i++) { jstring jstr = (jstring)(jenv->GetObjectArrayElement(jvalues, i)); - const char *value = jenv->GetStringUTFChars(jstr, 0); - values.emplace_back(value); - if (value) jenv->ReleaseStringUTFChars(jstr, value); + std::unique_ptr> value{jenv->GetStringUTFChars(jstr, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jstr, ptr); + }}; + values.emplace_back(value.get()); } - std::vector c_values; + std::vector c_values; c_values.resize(size); - std::transform(values.cbegin(), values.cend(), - c_values.begin(), + std::transform(values.cbegin(), values.cend(), c_values.begin(), [](auto const &str) { return str.c_str(); }); - int ret = XGDMatrixSetStrFeatureInfo(handle, field, c_values.data(), size); - JVM_CHECK_CALL(ret); - - if (field) jenv->ReleaseStringUTFChars(jfield, field); - return ret; + return XGDMatrixSetStrFeatureInfo(handle, field.get(), c_values.data(), size); } /* @@ -1210,28 +1419,29 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetStrFea * Method: XGDMatrixGetStrFeatureInfo * Signature: (JLjava/lang/String;[J[[Ljava/lang/String;)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetStrFeatureInfo - (JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield, jlongArray joutLenArray, - jobjectArray joutValueArray) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - const char *field = jenv->GetStringUTFChars(jfield, 0); +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetStrFeatureInfo( + JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield, jlongArray joutLenArray, + jobjectArray joutValueArray) { + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> field{jenv->GetStringUTFChars(jfield, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jfield, ptr); + }}; bst_ulong out_len = 0; char const **c_out_features; - int ret = XGDMatrixGetStrFeatureInfo(handle, field, &out_len, &c_out_features); + int ret = XGDMatrixGetStrFeatureInfo(handle, field.get(), &out_len, &c_out_features); - jlong jlen = (jlong) out_len; + jlong jlen = (jlong)out_len; jenv->SetLongArrayRegion(joutLenArray, 0, 1, &jlen); - jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), - jenv->NewStringUTF("")); + jobjectArray jinfos = + jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF("")); for (int i = 0; i < jlen; i++) { jenv->SetObjectArrayElement(jinfos, i, jenv->NewStringUTF(c_out_features[i])); } jenv->SetObjectArrayElement(joutValueArray, 0, jinfos); - JVM_CHECK_CALL(ret); - if (field) jenv->ReleaseStringUTFChars(jfield, field); return ret; } @@ -1244,10 +1454,12 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo( JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield, jobjectArray jfeatures) { - BoosterHandle handle = (BoosterHandle)jhandle; - - const char *field = jenv->GetStringUTFChars(jfield, 0); + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> field{jenv->GetStringUTFChars(jfield, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jfield, ptr); + }}; bst_ulong feature_num = (bst_ulong)jenv->GetArrayLength(jfeatures); std::vector features; @@ -1255,19 +1467,21 @@ Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo( for (bst_ulong i = 0; i < feature_num; ++i) { jstring jfeature = (jstring)jenv->GetObjectArrayElement(jfeatures, i); - const char *s = jenv->GetStringUTFChars(jfeature, 0); - features.push_back(std::string(s, jenv->GetStringLength(jfeature))); - if (s != nullptr) jenv->ReleaseStringUTFChars(jfeature, s); + std::unique_ptr> s{ + jenv->GetStringUTFChars(jfeature, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jfeature, ptr); + } + }}; + features.emplace_back(s.get(), jenv->GetStringLength(jfeature)); } for (size_t i = 0; i < features.size(); ++i) { features_char.push_back(features[i].c_str()); } - int ret = XGBoosterSetStrFeatureInfo( - handle, field, dmlc::BeginPtr(features_char), feature_num); - JVM_CHECK_CALL(ret); - return ret; + return XGBoosterSetStrFeatureInfo(handle, field.get(), dmlc::BeginPtr(features_char), + feature_num); } /* @@ -1279,17 +1493,19 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo( JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield, jobjectArray jout) { - BoosterHandle handle = (BoosterHandle)jhandle; - - const char *field = jenv->GetStringUTFChars(jfield, 0); + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> field{jenv->GetStringUTFChars(jfield, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jfield, ptr); + }}; bst_ulong feature_num = (bst_ulong)jenv->GetArrayLength(jout); const char **features; std::vector features_char; - int ret = XGBoosterGetStrFeatureInfo(handle, field, &feature_num, - (const char ***)&features); + int ret = + XGBoosterGetStrFeatureInfo(handle, field.get(), &feature_num, (const char ***)&features); JVM_CHECK_CALL(ret); for (bst_ulong i = 0; i < feature_num; i++) { diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h index cc4ad53d4e4c..c8e48cfc9de9 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.h +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -306,10 +306,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoo /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: CommunicatorInit - * Signature: ([Ljava/lang/String;)I + * Signature: (Ljava/lang/String;)I */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit - (JNIEnv *, jclass, jobjectArray); + (JNIEnv *, jclass, jstring); /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI @@ -343,6 +343,46 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetRan JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorGetWorldSize (JNIEnv *, jclass, jintArray); +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: TrackerCreate + * Signature: (Ljava/lang/String;IIIJ[J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerCreate + (JNIEnv *, jclass, jstring, jint, jint, jint, jlong, jlongArray); + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: TrackerRun + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerRun + (JNIEnv *, jclass, jlong); + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: TrackerWaitFor + * Signature: (JJ)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWaitFor + (JNIEnv *, jclass, jlong, jlong); + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: TrackerWorkerArgs + * Signature: (JJ[Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWorkerArgs + (JNIEnv *, jclass, jlong, jlong, jobjectArray); + +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: TrackerFree + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerFree + (JNIEnv *, jclass, jlong); + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: CommunicatorAllreduce diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java index d658c55292c4..b6ffe84e30e9 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2022 by Contributors + Copyright (c) 2014-2024 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -298,7 +298,7 @@ public void testCreateFromDenseMatrixRef() throws XGBoostError { @Test public void testTrainWithDenseMatrixRef() throws XGBoostError { - Map rabitEnv = new HashMap<>(); + Map rabitEnv = new HashMap<>(); rabitEnv.put("DMLC_TASK_ID", "0"); Communicator.init(rabitEnv); DMatrix trainMat = null; diff --git a/plugin/federated/CMakeLists.txt b/plugin/federated/CMakeLists.txt index 0865e756e97b..9a51f59d5781 100644 --- a/plugin/federated/CMakeLists.txt +++ b/plugin/federated/CMakeLists.txt @@ -31,31 +31,13 @@ protobuf_generate( PLUGIN "protoc-gen-grpc=\$" PROTOC_OUT_DIR "${PROTO_BINARY_DIR}") -add_library(federated_old_proto STATIC federated.old.proto) -target_link_libraries(federated_old_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++) -target_include_directories(federated_old_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) -xgboost_target_properties(federated_old_proto) - -protobuf_generate( - TARGET federated_old_proto - LANGUAGE cpp - PROTOC_OUT_DIR "${PROTO_BINARY_DIR}") -protobuf_generate( - TARGET federated_old_proto - LANGUAGE grpc - GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc - PLUGIN "protoc-gen-grpc=\$" - PROTOC_OUT_DIR "${PROTO_BINARY_DIR}") - # Wrapper for the gRPC client. add_library(federated_client INTERFACE) -target_sources(federated_client INTERFACE federated_client.h) target_link_libraries(federated_client INTERFACE federated_proto) -target_link_libraries(federated_client INTERFACE federated_old_proto) # Rabit engine for Federated Learning. target_sources( - objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc federated_coll.cc + objxgboost PRIVATE federated_tracker.cc federated_comm.cc federated_coll.cc ) if(USE_CUDA) target_sources(objxgboost PRIVATE federated_comm.cu federated_coll.cu) diff --git a/plugin/federated/federated.old.proto b/plugin/federated/federated.old.proto deleted file mode 100644 index 8450659fd180..000000000000 --- a/plugin/federated/federated.old.proto +++ /dev/null @@ -1,81 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -syntax = "proto3"; - -package xgboost.federated; - -service Federated { - rpc Allgather(AllgatherRequest) returns (AllgatherReply) {} - rpc AllgatherV(AllgatherVRequest) returns (AllgatherVReply) {} - rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {} - rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {} -} - -enum DataType { - INT8 = 0; - UINT8 = 1; - INT32 = 2; - UINT32 = 3; - INT64 = 4; - UINT64 = 5; - FLOAT = 6; - DOUBLE = 7; -} - -enum ReduceOperation { - MAX = 0; - MIN = 1; - SUM = 2; - BITWISE_AND = 3; - BITWISE_OR = 4; - BITWISE_XOR = 5; -} - -message AllgatherRequest { - // An incrementing counter that is unique to each round to operations. - uint64 sequence_number = 1; - int32 rank = 2; - bytes send_buffer = 3; -} - -message AllgatherReply { - bytes receive_buffer = 1; -} - -message AllgatherVRequest { - // An incrementing counter that is unique to each round to operations. - uint64 sequence_number = 1; - int32 rank = 2; - bytes send_buffer = 3; -} - -message AllgatherVReply { - bytes receive_buffer = 1; -} - -message AllreduceRequest { - // An incrementing counter that is unique to each round to operations. - uint64 sequence_number = 1; - int32 rank = 2; - bytes send_buffer = 3; - DataType data_type = 4; - ReduceOperation reduce_operation = 5; -} - -message AllreduceReply { - bytes receive_buffer = 1; -} - -message BroadcastRequest { - // An incrementing counter that is unique to each round to operations. - uint64 sequence_number = 1; - int32 rank = 2; - bytes send_buffer = 3; - // The root rank to broadcast from. - int32 root = 4; -} - -message BroadcastReply { - bytes receive_buffer = 1; -} diff --git a/plugin/federated/federated_client.h b/plugin/federated/federated_client.h deleted file mode 100644 index 0122a5cfe153..000000000000 --- a/plugin/federated/federated_client.h +++ /dev/null @@ -1,132 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#pragma once -#include -#include -#include - -#include -#include -#include -#include - -namespace xgboost::federated { -/** - * @brief A wrapper around the gRPC client. - */ -class FederatedClient { - public: - FederatedClient(std::string const &server_address, int rank, std::string const &server_cert, - std::string const &client_key, std::string const &client_cert) - : stub_{[&] { - grpc::SslCredentialsOptions options; - options.pem_root_certs = server_cert; - options.pem_private_key = client_key; - options.pem_cert_chain = client_cert; - grpc::ChannelArguments args; - args.SetMaxReceiveMessageSize(std::numeric_limits::max()); - auto channel = - grpc::CreateCustomChannel(server_address, grpc::SslCredentials(options), args); - channel->WaitForConnected( - gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN))); - return Federated::NewStub(channel); - }()}, - rank_{rank} {} - - /** @brief Insecure client for connecting to localhost only. */ - FederatedClient(std::string const &server_address, int rank) - : stub_{[&] { - grpc::ChannelArguments args; - args.SetMaxReceiveMessageSize(std::numeric_limits::max()); - return Federated::NewStub( - grpc::CreateCustomChannel(server_address, grpc::InsecureChannelCredentials(), args)); - }()}, - rank_{rank} {} - - std::string Allgather(std::string_view send_buffer) { - AllgatherRequest request; - request.set_sequence_number(sequence_number_++); - request.set_rank(rank_); - request.set_send_buffer(send_buffer.data(), send_buffer.size()); - - AllgatherReply reply; - grpc::ClientContext context; - context.set_wait_for_ready(true); - grpc::Status status = stub_->Allgather(&context, request, &reply); - - if (status.ok()) { - return reply.receive_buffer(); - } else { - std::cout << status.error_code() << ": " << status.error_message() << '\n'; - throw std::runtime_error("Allgather RPC failed"); - } - } - - std::string AllgatherV(std::string_view send_buffer) { - AllgatherVRequest request; - request.set_sequence_number(sequence_number_++); - request.set_rank(rank_); - request.set_send_buffer(send_buffer.data(), send_buffer.size()); - - AllgatherVReply reply; - grpc::ClientContext context; - context.set_wait_for_ready(true); - grpc::Status status = stub_->AllgatherV(&context, request, &reply); - - if (status.ok()) { - return reply.receive_buffer(); - } else { - std::cout << status.error_code() << ": " << status.error_message() << '\n'; - throw std::runtime_error("AllgatherV RPC failed"); - } - } - - std::string Allreduce(std::string const &send_buffer, DataType data_type, - ReduceOperation reduce_operation) { - AllreduceRequest request; - request.set_sequence_number(sequence_number_++); - request.set_rank(rank_); - request.set_send_buffer(send_buffer); - request.set_data_type(data_type); - request.set_reduce_operation(reduce_operation); - - AllreduceReply reply; - grpc::ClientContext context; - context.set_wait_for_ready(true); - grpc::Status status = stub_->Allreduce(&context, request, &reply); - - if (status.ok()) { - return reply.receive_buffer(); - } else { - std::cout << status.error_code() << ": " << status.error_message() << '\n'; - throw std::runtime_error("Allreduce RPC failed"); - } - } - - std::string Broadcast(std::string const &send_buffer, int root) { - BroadcastRequest request; - request.set_sequence_number(sequence_number_++); - request.set_rank(rank_); - request.set_send_buffer(send_buffer); - request.set_root(root); - - BroadcastReply reply; - grpc::ClientContext context; - context.set_wait_for_ready(true); - grpc::Status status = stub_->Broadcast(&context, request, &reply); - - if (status.ok()) { - return reply.receive_buffer(); - } else { - std::cout << status.error_code() << ": " << status.error_message() << '\n'; - throw std::runtime_error("Broadcast RPC failed"); - } - } - - private: - std::unique_ptr const stub_; - int const rank_; - uint64_t sequence_number_{}; -}; -} // namespace xgboost::federated diff --git a/plugin/federated/federated_coll.cc b/plugin/federated/federated_coll.cc index 980992d61100..b62abdada5a5 100644 --- a/plugin/federated/federated_coll.cc +++ b/plugin/federated/federated_coll.cc @@ -89,19 +89,15 @@ Coll *FederatedColl::MakeCUDAVar() { [[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span data, std::int32_t root) { - if (comm.Rank() == root) { - return BroadcastImpl(comm, &this->sequence_number_, data, root); - } else { - return BroadcastImpl(comm, &this->sequence_number_, data, root); - } + return BroadcastImpl(comm, &this->sequence_number_, data, root); } -[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span data, - std::int64_t size) { +[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span data) { using namespace federated; // NOLINT auto fed = dynamic_cast(&comm); CHECK(fed); auto stub = fed->Handle(); + auto size = data.size_bytes() / comm.World(); auto offset = comm.Rank() * size; auto segment = data.subspan(offset, size); diff --git a/plugin/federated/federated_coll.cu b/plugin/federated/federated_coll.cu index a922e1c11483..3f604c50d2d2 100644 --- a/plugin/federated/federated_coll.cu +++ b/plugin/federated/federated_coll.cu @@ -53,8 +53,7 @@ Coll *FederatedColl::MakeCUDAVar() { }; } -[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span data, - std::int64_t size) { +[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span data) { auto cufed = dynamic_cast(&comm); CHECK(cufed); std::vector h_data(data.size()); @@ -63,7 +62,7 @@ Coll *FederatedColl::MakeCUDAVar() { return GetCUDAResult( cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost)); } << [&] { - return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()}, size); + return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()}); } << [&] { return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(), cudaMemcpyHostToDevice, cufed->Stream())); diff --git a/plugin/federated/federated_coll.cuh b/plugin/federated/federated_coll.cuh index a1121d88f533..6a690a33d889 100644 --- a/plugin/federated/federated_coll.cuh +++ b/plugin/federated/federated_coll.cuh @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost contributors + * Copyright 2023-2024, XGBoost contributors */ #include "../../src/collective/comm.h" // for Comm, Coll #include "federated_coll.h" // for FederatedColl @@ -16,8 +16,7 @@ class CUDAFederatedColl : public Coll { ArrayInterfaceHandler::Type type, Op op) override; [[nodiscard]] Result Broadcast(Comm const &comm, common::Span data, std::int32_t root) override; - [[nodiscard]] Result Allgather(Comm const &, common::Span data, - std::int64_t size) override; + [[nodiscard]] Result Allgather(Comm const &, common::Span data) override; [[nodiscard]] Result AllgatherV(Comm const &comm, common::Span data, common::Span sizes, common::Span recv_segments, diff --git a/plugin/federated/federated_coll.h b/plugin/federated/federated_coll.h index c261b01e11c6..12443a3e1b5a 100644 --- a/plugin/federated/federated_coll.h +++ b/plugin/federated/federated_coll.h @@ -1,12 +1,9 @@ /** - * Copyright 2023, XGBoost contributors + * Copyright 2023-2024, XGBoost contributors */ #pragma once #include "../../src/collective/coll.h" // for Coll #include "../../src/collective/comm.h" // for Comm -#include "../../src/common/io.h" // for ReadAll -#include "../../src/common/json_utils.h" // for OptionalArg -#include "xgboost/json.h" // for Json namespace xgboost::collective { class FederatedColl : public Coll { @@ -20,8 +17,7 @@ class FederatedColl : public Coll { ArrayInterfaceHandler::Type type, Op op) override; [[nodiscard]] Result Broadcast(Comm const &comm, common::Span data, std::int32_t root) override; - [[nodiscard]] Result Allgather(Comm const &, common::Span data, - std::int64_t) override; + [[nodiscard]] Result Allgather(Comm const &, common::Span data) override; [[nodiscard]] Result AllgatherV(Comm const &comm, common::Span data, common::Span sizes, common::Span recv_segments, diff --git a/plugin/federated/federated_comm.cuh b/plugin/federated/federated_comm.cuh index 58c52f67e28c..85cecb3eb331 100644 --- a/plugin/federated/federated_comm.cuh +++ b/plugin/federated/federated_comm.cuh @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once @@ -9,7 +9,6 @@ #include "../../src/common/device_helpers.cuh" // for CUDAStreamView #include "federated_comm.h" // for FederatedComm #include "xgboost/context.h" // for Context -#include "xgboost/logging.h" namespace xgboost::collective { class CUDAFederatedComm : public FederatedComm { diff --git a/plugin/federated/federated_comm.h b/plugin/federated/federated_comm.h index 750d94abd7dc..b39e1878a8ea 100644 --- a/plugin/federated/federated_comm.h +++ b/plugin/federated/federated_comm.h @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost contributors + * Copyright 2023-2024, XGBoost contributors */ #pragma once @@ -11,7 +11,6 @@ #include // for string #include "../../src/collective/comm.h" // for HostComm -#include "../../src/common/json_utils.h" // for OptionalArg #include "xgboost/json.h" namespace xgboost::collective { @@ -51,6 +50,10 @@ class FederatedComm : public HostComm { std::int32_t rank) { this->Init(host, port, world, rank, {}, {}, {}); } + [[nodiscard]] Result Shutdown() final { + this->ResetState(); + return Success(); + } ~FederatedComm() override { stub_.reset(); } [[nodiscard]] std::shared_ptr Chan(std::int32_t) const override { @@ -65,5 +68,13 @@ class FederatedComm : public HostComm { [[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); } [[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr pimpl) const override; + /** + * @brief Get a string ID for the current process. + */ + [[nodiscard]] Result ProcessorName(std::string* out) const final { + auto rank = this->Rank(); + *out = "rank:" + std::to_string(rank); + return Success(); + }; }; } // namespace xgboost::collective diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h deleted file mode 100644 index 46c6b0fda672..000000000000 --- a/plugin/federated/federated_communicator.h +++ /dev/null @@ -1,195 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#pragma once -#include - -#include "../../src/c_api/c_api_utils.h" -#include "../../src/collective/communicator.h" -#include "../../src/common/io.h" -#include "federated_client.h" - -namespace xgboost::collective { -/** - * @brief A Federated Learning communicator class that handles collective communication. - */ -class FederatedCommunicator : public Communicator { - public: - /** - * @brief Create a new communicator based on JSON configuration. - * @param config JSON configuration. - * @return Communicator as specified by the JSON configuration. - */ - static Communicator *Create(Json const &config) { - std::string server_address{}; - int world_size{0}; - int rank{-1}; - std::string server_cert{}; - std::string client_key{}; - std::string client_cert{}; - - // Parse environment variables first. - auto *value = getenv("FEDERATED_SERVER_ADDRESS"); - if (value != nullptr) { - server_address = value; - } - value = getenv("FEDERATED_WORLD_SIZE"); - if (value != nullptr) { - world_size = std::stoi(value); - } - value = getenv("FEDERATED_RANK"); - if (value != nullptr) { - rank = std::stoi(value); - } - value = getenv("FEDERATED_SERVER_CERT"); - if (value != nullptr) { - server_cert = value; - } - value = getenv("FEDERATED_CLIENT_KEY"); - if (value != nullptr) { - client_key = value; - } - value = getenv("FEDERATED_CLIENT_CERT"); - if (value != nullptr) { - client_cert = value; - } - - // Runtime configuration overrides, optional as users can specify them as env vars. - server_address = OptionalArg(config, "federated_server_address", server_address); - world_size = - OptionalArg(config, "federated_world_size", static_cast(world_size)); - rank = OptionalArg(config, "federated_rank", static_cast(rank)); - server_cert = OptionalArg(config, "federated_server_cert", server_cert); - client_key = OptionalArg(config, "federated_client_key", client_key); - client_cert = OptionalArg(config, "federated_client_cert", client_cert); - - if (server_address.empty()) { - LOG(FATAL) << "Federated server address must be set."; - } - if (world_size == 0) { - LOG(FATAL) << "Federated world size must be set."; - } - if (rank == -1) { - LOG(FATAL) << "Federated rank must be set."; - } - return new FederatedCommunicator(world_size, rank, server_address, server_cert, client_key, - client_cert); - } - - /** - * @brief Construct a new federated communicator. - * - * @param world_size Total number of processes. - * @param rank Rank of the current process. - * @param server_address Address of the federated server (host:port). - * @param server_cert_path Path to the server cert file. - * @param client_key_path Path to the client key file. - * @param client_cert_path Path to the client cert file. - */ - FederatedCommunicator(int world_size, int rank, std::string const &server_address, - std::string const &server_cert_path, std::string const &client_key_path, - std::string const &client_cert_path) - : Communicator{world_size, rank} { - if (server_cert_path.empty() || client_key_path.empty() || client_cert_path.empty()) { - client_.reset(new xgboost::federated::FederatedClient(server_address, rank)); - } else { - client_.reset(new xgboost::federated::FederatedClient( - server_address, rank, xgboost::common::ReadAll(server_cert_path), - xgboost::common::ReadAll(client_key_path), xgboost::common::ReadAll(client_cert_path))); - } - } - - /** - * @brief Construct an insecure federated communicator without using SSL. - * @param world_size Total number of processes. - * @param rank Rank of the current process. - * @param server_address Address of the federated server (host:port). - */ - FederatedCommunicator(int world_size, int rank, std::string const &server_address) - : Communicator{world_size, rank} { - client_.reset(new xgboost::federated::FederatedClient(server_address, rank)); - } - - ~FederatedCommunicator() override { client_.reset(); } - - /** - * \brief Get if the communicator is distributed. - * \return True. - */ - [[nodiscard]] bool IsDistributed() const override { return true; } - - /** - * \brief Get if the communicator is federated. - * \return True. - */ - [[nodiscard]] bool IsFederated() const override { return true; } - - /** - * \brief Perform allgather. - * \param input Buffer for sending data. - */ - std::string AllGather(std::string_view input) override { - return client_->Allgather(input); - } - - /** - * \brief Perform variable-length allgather. - * \param input Buffer for sending data. - */ - std::string AllGatherV(std::string_view input) override { - return client_->AllgatherV(input); - } - - /** - * \brief Perform in-place allreduce. - * \param send_receive_buffer Buffer for both sending and receiving data. - * \param count Number of elements to be reduced. - * \param data_type Enumeration of data type. - * \param op Enumeration of operation type. - */ - void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, - Operation op) override { - std::string const send_buffer(reinterpret_cast(send_receive_buffer), - count * GetTypeSize(data_type)); - auto const received = - client_->Allreduce(send_buffer, static_cast(data_type), - static_cast(op)); - received.copy(reinterpret_cast(send_receive_buffer), count * GetTypeSize(data_type)); - } - - /** - * \brief Broadcast a memory region to all others from root. - * \param send_receive_buffer Pointer to the send or receive buffer. - * \param size Size of the data. - * \param root The process rank to broadcast from. - */ - void Broadcast(void *send_receive_buffer, std::size_t size, int root) override { - if (GetWorldSize() == 1) return; - if (GetRank() == root) { - std::string const send_buffer(reinterpret_cast(send_receive_buffer), size); - client_->Broadcast(send_buffer, root); - } else { - auto const received = client_->Broadcast("", root); - received.copy(reinterpret_cast(send_receive_buffer), size); - } - } - - /** - * \brief Get the name of the processor. - * \return Name of the processor. - */ - std::string GetProcessorName() override { return "rank" + std::to_string(GetRank()); } - - /** - * \brief Print the message to the communicator. - * \param message The message to be printed. - */ - void Print(const std::string &message) override { LOG(CONSOLE) << message; } - - protected: - void Shutdown() override {} - - private: - std::unique_ptr client_{}; -}; -} // namespace xgboost::collective diff --git a/plugin/federated/federated_server.cc b/plugin/federated/federated_server.cc deleted file mode 100644 index 9dd97c2e19a7..000000000000 --- a/plugin/federated/federated_server.cc +++ /dev/null @@ -1,86 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#include "federated_server.h" - -#include -#include // for Server -#include -#include - -#include - -#include "../../src/collective/comm.h" -#include "../../src/common/io.h" -#include "../../src/common/json_utils.h" - -namespace xgboost::federated { -grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request, - AllgatherReply* reply) { - handler_.Allgather(request->send_buffer().data(), request->send_buffer().size(), - reply->mutable_receive_buffer(), request->sequence_number(), request->rank()); - return grpc::Status::OK; -} - -grpc::Status FederatedService::AllgatherV(grpc::ServerContext*, AllgatherVRequest const* request, - AllgatherVReply* reply) { - handler_.AllgatherV(request->send_buffer().data(), request->send_buffer().size(), - reply->mutable_receive_buffer(), request->sequence_number(), request->rank()); - return grpc::Status::OK; -} - -grpc::Status FederatedService::Allreduce(grpc::ServerContext*, AllreduceRequest const* request, - AllreduceReply* reply) { - handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(), - reply->mutable_receive_buffer(), request->sequence_number(), request->rank(), - static_cast(request->data_type()), - static_cast(request->reduce_operation())); - return grpc::Status::OK; -} - -grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest const* request, - BroadcastReply* reply) { - handler_.Broadcast(request->send_buffer().data(), request->send_buffer().size(), - reply->mutable_receive_buffer(), request->sequence_number(), request->rank(), - request->root()); - return grpc::Status::OK; -} - -void RunServer(int port, std::size_t world_size, char const* server_key_file, - char const* server_cert_file, char const* client_cert_file) { - std::string const server_address = "0.0.0.0:" + std::to_string(port); - FederatedService service{static_cast(world_size)}; - - grpc::ServerBuilder builder; - auto options = - grpc::SslServerCredentialsOptions(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); - options.pem_root_certs = xgboost::common::ReadAll(client_cert_file); - auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair(); - key.private_key = xgboost::common::ReadAll(server_key_file); - key.cert_chain = xgboost::common::ReadAll(server_cert_file); - options.pem_key_cert_pairs.push_back(key); - builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); - builder.AddListeningPort(server_address, grpc::SslServerCredentials(options)); - builder.RegisterService(&service); - std::unique_ptr server(builder.BuildAndStart()); - LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size " - << world_size; - - server->Wait(); -} - -void RunInsecureServer(int port, std::size_t world_size) { - std::string const server_address = "0.0.0.0:" + std::to_string(port); - FederatedService service{static_cast(world_size)}; - - grpc::ServerBuilder builder; - builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); - builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); - builder.RegisterService(&service); - std::unique_ptr server(builder.BuildAndStart()); - LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size " - << world_size; - - server->Wait(); -} -} // namespace xgboost::federated diff --git a/plugin/federated/federated_server.h b/plugin/federated/federated_server.h deleted file mode 100644 index de760d9d809e..000000000000 --- a/plugin/federated/federated_server.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2022-2023, XGBoost contributors - */ -#pragma once - -#include - -#include // for int32_t -#include // for future - -#include "../../src/collective/in_memory_handler.h" -#include "../../src/collective/tracker.h" // for Tracker -#include "xgboost/collective/result.h" // for Result - -namespace xgboost::federated { -class FederatedService final : public Federated::Service { - public: - explicit FederatedService(std::int32_t world_size) - : handler_{static_cast(world_size)} {} - - grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request, - AllgatherReply* reply) override; - - grpc::Status AllgatherV(grpc::ServerContext* context, AllgatherVRequest const* request, - AllgatherVReply* reply) override; - - grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request, - AllreduceReply* reply) override; - - grpc::Status Broadcast(grpc::ServerContext* context, BroadcastRequest const* request, - BroadcastReply* reply) override; - - private: - xgboost::collective::InMemoryHandler handler_; -}; - -void RunServer(int port, std::size_t world_size, char const* server_key_file, - char const* server_cert_file, char const* client_cert_file); - -void RunInsecureServer(int port, std::size_t world_size); -} // namespace xgboost::federated diff --git a/plugin/federated/federated_tracker.cc b/plugin/federated/federated_tracker.cc index 37b6c36393e1..95c0824d9736 100644 --- a/plugin/federated/federated_tracker.cc +++ b/plugin/federated/federated_tracker.cc @@ -1,5 +1,5 @@ /** - * Copyright 2022-2023, XGBoost contributors + * Copyright 2022-2024, XGBoost contributors */ #include "federated_tracker.h" @@ -8,13 +8,12 @@ #include // for int32_t #include // for exception +#include // for future, async #include // for numeric_limits #include // for string -#include // for sleep_for #include "../../src/common/io.h" // for ReadAll #include "../../src/common/json_utils.h" // for RequiredArg -#include "../../src/common/timer.h" // for Timer namespace xgboost::collective { namespace federated { @@ -36,8 +35,8 @@ grpc::Status FederatedService::Allreduce(grpc::ServerContext*, AllreduceRequest AllreduceReply* reply) { handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(), reply->mutable_receive_buffer(), request->sequence_number(), request->rank(), - static_cast(request->data_type()), - static_cast(request->reduce_operation())); + static_cast(request->data_type()), + static_cast(request->reduce_operation())); return grpc::Status::OK; } @@ -53,9 +52,13 @@ grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} { auto is_secure = RequiredArg(config, "federated_secure", __func__); if (is_secure) { + StringView msg{"Empty certificate path."}; server_key_path_ = RequiredArg(config, "server_key_path", __func__); + CHECK(!server_key_path_.empty()) << msg; server_cert_file_ = RequiredArg(config, "server_cert_path", __func__); + CHECK(!server_cert_file_.empty()) << msg; client_cert_file_ = RequiredArg(config, "client_cert_path", __func__); + CHECK(!client_cert_file_.empty()) << msg; } } @@ -125,14 +128,14 @@ Result FederatedTracker::Shutdown() { [[nodiscard]] Json FederatedTracker::WorkerArgs() const { auto rc = this->WaitUntilReady(); - CHECK(rc.OK()) << rc.Report(); + SafeColl(rc); std::string host; rc = GetHostAddress(&host); CHECK(rc.OK()); Json args{Object{}}; - args["DMLC_TRACKER_URI"] = String{host}; - args["DMLC_TRACKER_PORT"] = this->Port(); + args["dmlc_tracker_uri"] = String{host}; + args["dmlc_tracker_port"] = this->Port(); return args; } } // namespace xgboost::collective diff --git a/plugin/federated/federated_tracker.h b/plugin/federated/federated_tracker.h index 33592fefea0c..ac46b6eaa183 100644 --- a/plugin/federated/federated_tracker.h +++ b/plugin/federated/federated_tracker.h @@ -17,8 +17,7 @@ namespace xgboost::collective { namespace federated { class FederatedService final : public Federated::Service { public: - explicit FederatedService(std::int32_t world_size) - : handler_{static_cast(world_size)} {} + explicit FederatedService(std::int32_t world_size) : handler_{world_size} {} grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request, AllgatherReply* reply) override; diff --git a/plugin/sycl/common/hist_util.cc b/plugin/sycl/common/hist_util.cc new file mode 100644 index 000000000000..fd813a92cec9 --- /dev/null +++ b/plugin/sycl/common/hist_util.cc @@ -0,0 +1,334 @@ +/*! + * Copyright 2017-2023 by Contributors + * \file hist_util.cc + */ +#include +#include +#include + +#include "../data/gradient_index.h" +#include "hist_util.h" + +#include + +namespace xgboost { +namespace sycl { +namespace common { + +/*! + * \brief Fill histogram with zeroes + */ +template +void InitHist(::sycl::queue qu, GHistRow* hist, + size_t size, ::sycl::event* event) { + *event = qu.fill(hist->Begin(), + xgboost::detail::GradientPairInternal(), size, *event); +} +template void InitHist(::sycl::queue qu, + GHistRow* hist, + size_t size, ::sycl::event* event); +template void InitHist(::sycl::queue qu, + GHistRow* hist, + size_t size, ::sycl::event* event); + +/*! + * \brief Compute Subtraction: dst = src1 - src2 + */ +template +::sycl::event SubtractionHist(::sycl::queue qu, + GHistRow* dst, + const GHistRow& src1, + const GHistRow& src2, + size_t size, ::sycl::event event_priv) { + GradientSumT* pdst = reinterpret_cast(dst->Data()); + const GradientSumT* psrc1 = reinterpret_cast(src1.DataConst()); + const GradientSumT* psrc2 = reinterpret_cast(src2.DataConst()); + + auto event_final = qu.submit([&](::sycl::handler& cgh) { + cgh.depends_on(event_priv); + cgh.parallel_for<>(::sycl::range<1>(2 * size), [pdst, psrc1, psrc2](::sycl::item<1> pid) { + const size_t i = pid.get_id(0); + pdst[i] = psrc1[i] - psrc2[i]; + }); + }); + return event_final; +} +template ::sycl::event SubtractionHist(::sycl::queue qu, + GHistRow* dst, + const GHistRow& src1, + const GHistRow& src2, + size_t size, ::sycl::event event_priv); +template ::sycl::event SubtractionHist(::sycl::queue qu, + GHistRow* dst, + const GHistRow& src1, + const GHistRow& src2, + size_t size, ::sycl::event event_priv); + +// Kernel with buffer using +template +::sycl::event BuildHistKernel(::sycl::queue qu, + const USMVector& gpair_device, + const RowSetCollection::Elem& row_indices, + const GHistIndexMatrix& gmat, + GHistRow* hist, + GHistRow* hist_buffer, + ::sycl::event event_priv) { + const size_t size = row_indices.Size(); + const size_t* rid = row_indices.begin; + const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride; + const GradientPair::ValueT* pgh = + reinterpret_cast(gpair_device.DataConst()); + const BinIdxType* gradient_index = gmat.index.data(); + const uint32_t* offsets = gmat.index.Offset(); + FPType* hist_data = reinterpret_cast(hist->Data()); + const size_t nbins = gmat.nbins; + + const size_t max_work_group_size = + qu.get_device().get_info<::sycl::info::device::max_work_group_size>(); + const size_t work_group_size = n_columns < max_work_group_size ? n_columns : max_work_group_size; + + const size_t max_nblocks = hist_buffer->Size() / (nbins * 2); + const size_t min_block_size = 128; + size_t nblocks = std::min(max_nblocks, size / min_block_size + !!(size % min_block_size)); + const size_t block_size = size / nblocks + !!(size % nblocks); + FPType* hist_buffer_data = reinterpret_cast(hist_buffer->Data()); + + auto event_fill = qu.fill(hist_buffer_data, FPType(0), nblocks * nbins * 2, event_priv); + auto event_main = qu.submit([&](::sycl::handler& cgh) { + cgh.depends_on(event_fill); + cgh.parallel_for<>(::sycl::nd_range<2>(::sycl::range<2>(nblocks, work_group_size), + ::sycl::range<2>(1, work_group_size)), + [=](::sycl::nd_item<2> pid) { + size_t block = pid.get_global_id(0); + size_t feat = pid.get_global_id(1); + + FPType* hist_local = hist_buffer_data + block * nbins * 2; + for (size_t idx = 0; idx < block_size; ++idx) { + size_t i = block * block_size + idx; + if (i < size) { + const size_t icol_start = n_columns * rid[i]; + const size_t idx_gh = rid[i]; + + pid.barrier(::sycl::access::fence_space::local_space); + const BinIdxType* gr_index_local = gradient_index + icol_start; + + for (size_t j = feat; j < n_columns; j += work_group_size) { + uint32_t idx_bin = static_cast(gr_index_local[j]); + if constexpr (isDense) { + idx_bin += offsets[j]; + } + if (idx_bin < nbins) { + hist_local[2 * idx_bin] += pgh[2 * idx_gh]; + hist_local[2 * idx_bin+1] += pgh[2 * idx_gh+1]; + } + } + } + } + }); + }); + + auto event_save = qu.submit([&](::sycl::handler& cgh) { + cgh.depends_on(event_main); + cgh.parallel_for<>(::sycl::range<1>(nbins), [=](::sycl::item<1> pid) { + size_t idx_bin = pid.get_id(0); + + FPType gsum = 0.0f; + FPType hsum = 0.0f; + + for (size_t j = 0; j < nblocks; ++j) { + gsum += hist_buffer_data[j * nbins * 2 + 2 * idx_bin]; + hsum += hist_buffer_data[j * nbins * 2 + 2 * idx_bin + 1]; + } + + hist_data[2 * idx_bin] = gsum; + hist_data[2 * idx_bin + 1] = hsum; + }); + }); + return event_save; +} + +// Kernel with atomic using +template +::sycl::event BuildHistKernel(::sycl::queue qu, + const USMVector& gpair_device, + const RowSetCollection::Elem& row_indices, + const GHistIndexMatrix& gmat, + GHistRow* hist, + ::sycl::event event_priv) { + const size_t size = row_indices.Size(); + const size_t* rid = row_indices.begin; + const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride; + const GradientPair::ValueT* pgh = + reinterpret_cast(gpair_device.DataConst()); + const BinIdxType* gradient_index = gmat.index.data(); + const uint32_t* offsets = gmat.index.Offset(); + FPType* hist_data = reinterpret_cast(hist->Data()); + const size_t nbins = gmat.nbins; + + const size_t max_work_group_size = + qu.get_device().get_info<::sycl::info::device::max_work_group_size>(); + const size_t feat_local = n_columns < max_work_group_size ? n_columns : max_work_group_size; + + auto event_fill = qu.fill(hist_data, FPType(0), nbins * 2, event_priv); + auto event_main = qu.submit([&](::sycl::handler& cgh) { + cgh.depends_on(event_fill); + cgh.parallel_for<>(::sycl::range<2>(size, feat_local), + [=](::sycl::item<2> pid) { + size_t i = pid.get_id(0); + size_t feat = pid.get_id(1); + + const size_t icol_start = n_columns * rid[i]; + const size_t idx_gh = rid[i]; + + const BinIdxType* gr_index_local = gradient_index + icol_start; + + for (size_t j = feat; j < n_columns; j += feat_local) { + uint32_t idx_bin = static_cast(gr_index_local[j]); + if constexpr (isDense) { + idx_bin += offsets[j]; + } + if (idx_bin < nbins) { + AtomicRef gsum(hist_data[2 * idx_bin]); + AtomicRef hsum(hist_data[2 * idx_bin + 1]); + gsum.fetch_add(pgh[2 * idx_gh]); + hsum.fetch_add(pgh[2 * idx_gh + 1]); + } + } + }); + }); + return event_main; +} + +template +::sycl::event BuildHistDispatchKernel( + ::sycl::queue qu, + const USMVector& gpair_device, + const RowSetCollection::Elem& row_indices, + const GHistIndexMatrix& gmat, + GHistRow* hist, + bool isDense, + GHistRow* hist_buffer, + ::sycl::event events_priv, + bool force_atomic_use) { + const size_t size = row_indices.Size(); + const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride; + const size_t nbins = gmat.nbins; + + // max cycle size, while atomics are still effective + const size_t max_cycle_size_atomics = nbins; + const size_t cycle_size = size; + + // TODO(razdoburdin): replace the add-hock dispatching criteria by more sutable one + bool use_atomic = (size < nbins) || (gmat.max_num_bins == gmat.nbins / n_columns); + + // force_atomic_use flag is used only for testing + use_atomic = use_atomic || force_atomic_use; + if (!use_atomic) { + if (isDense) { + return BuildHistKernel(qu, gpair_device, row_indices, + gmat, hist, hist_buffer, + events_priv); + } else { + return BuildHistKernel(qu, gpair_device, row_indices, + gmat, hist, hist_buffer, + events_priv); + } + } else { + if (isDense) { + return BuildHistKernel(qu, gpair_device, row_indices, + gmat, hist, events_priv); + } else { + return BuildHistKernel(qu, gpair_device, row_indices, + gmat, hist, events_priv); + } + } +} + +template +::sycl::event BuildHistKernel(::sycl::queue qu, + const USMVector& gpair_device, + const RowSetCollection::Elem& row_indices, + const GHistIndexMatrix& gmat, const bool isDense, + GHistRow* hist, + GHistRow* hist_buffer, + ::sycl::event event_priv, + bool force_atomic_use) { + const bool is_dense = isDense; + switch (gmat.index.GetBinTypeSize()) { + case BinTypeSize::kUint8BinsTypeSize: + return BuildHistDispatchKernel(qu, gpair_device, row_indices, + gmat, hist, is_dense, hist_buffer, + event_priv, force_atomic_use); + break; + case BinTypeSize::kUint16BinsTypeSize: + return BuildHistDispatchKernel(qu, gpair_device, row_indices, + gmat, hist, is_dense, hist_buffer, + event_priv, force_atomic_use); + break; + case BinTypeSize::kUint32BinsTypeSize: + return BuildHistDispatchKernel(qu, gpair_device, row_indices, + gmat, hist, is_dense, hist_buffer, + event_priv, force_atomic_use); + break; + default: + CHECK(false); // no default behavior + } +} + +template +::sycl::event GHistBuilder::BuildHist( + const USMVector& gpair_device, + const RowSetCollection::Elem& row_indices, + const GHistIndexMatrix &gmat, + GHistRowT* hist, + bool isDense, + GHistRowT* hist_buffer, + ::sycl::event event_priv, + bool force_atomic_use) { + return BuildHistKernel(qu_, gpair_device, row_indices, gmat, + isDense, hist, hist_buffer, event_priv, + force_atomic_use); +} + +template +::sycl::event GHistBuilder::BuildHist( + const USMVector& gpair_device, + const RowSetCollection::Elem& row_indices, + const GHistIndexMatrix& gmat, + GHistRow* hist, + bool isDense, + GHistRow* hist_buffer, + ::sycl::event event_priv, + bool force_atomic_use); +template +::sycl::event GHistBuilder::BuildHist( + const USMVector& gpair_device, + const RowSetCollection::Elem& row_indices, + const GHistIndexMatrix& gmat, + GHistRow* hist, + bool isDense, + GHistRow* hist_buffer, + ::sycl::event event_priv, + bool force_atomic_use); + +template +void GHistBuilder::SubtractionTrick(GHistRowT* self, + const GHistRowT& sibling, + const GHistRowT& parent) { + const size_t size = self->Size(); + CHECK_EQ(sibling.Size(), size); + CHECK_EQ(parent.Size(), size); + + SubtractionHist(qu_, self, parent, sibling, size, ::sycl::event()); +} +template +void GHistBuilder::SubtractionTrick(GHistRow* self, + const GHistRow& sibling, + const GHistRow& parent); +template +void GHistBuilder::SubtractionTrick(GHistRow* self, + const GHistRow& sibling, + const GHistRow& parent); +} // namespace common +} // namespace sycl +} // namespace xgboost diff --git a/plugin/sycl/common/hist_util.h b/plugin/sycl/common/hist_util.h new file mode 100644 index 000000000000..aa9b4f5817bb --- /dev/null +++ b/plugin/sycl/common/hist_util.h @@ -0,0 +1,176 @@ +/*! + * Copyright 2017-2023 by Contributors + * \file hist_util.h + */ +#ifndef PLUGIN_SYCL_COMMON_HIST_UTIL_H_ +#define PLUGIN_SYCL_COMMON_HIST_UTIL_H_ + +#include +#include +#include + +#include "../data.h" +#include "row_set.h" + +#include "../../src/common/hist_util.h" +#include "../data/gradient_index.h" + +#include + +namespace xgboost { +namespace sycl { +namespace common { + +template +using GHistRow = USMVector, memory_type>; + +using BinTypeSize = ::xgboost::common::BinTypeSize; + +class ColumnMatrix; + +/*! + * \brief Fill histogram with zeroes + */ +template +void InitHist(::sycl::queue qu, + GHistRow* hist, + size_t size, ::sycl::event* event); + +/*! + * \brief Compute subtraction: dst = src1 - src2 + */ +template +::sycl::event SubtractionHist(::sycl::queue qu, + GHistRow* dst, + const GHistRow& src1, + const GHistRow& src2, + size_t size, ::sycl::event event_priv); + +/*! + * \brief Histograms of gradient statistics for multiple nodes + */ +template +class HistCollection { + public: + using GHistRowT = GHistRow; + + // Access histogram for i-th node + GHistRowT& operator[](bst_uint nid) { + return *(data_.at(nid)); + } + + const GHistRowT& operator[](bst_uint nid) const { + return *(data_.at(nid)); + } + + // Initialize histogram collection + void Init(::sycl::queue qu, uint32_t nbins) { + qu_ = qu; + if (nbins_ != nbins) { + nbins_ = nbins; + data_.clear(); + } + } + + // Create an empty histogram for i-th node + ::sycl::event AddHistRow(bst_uint nid) { + ::sycl::event event; + if (data_.count(nid) == 0) { + data_[nid] = + std::make_shared(&qu_, nbins_, + xgboost::detail::GradientPairInternal(0, 0), + &event); + } else { + data_[nid]->Resize(&qu_, nbins_, + xgboost::detail::GradientPairInternal(0, 0), + &event); + } + return event; + } + + private: + /*! \brief Number of all bins over all features */ + uint32_t nbins_ = 0; + + std::unordered_map> data_; + + ::sycl::queue qu_; +}; + +/*! + * \brief Stores temporary histograms to compute them in parallel + */ +template +class ParallelGHistBuilder { + public: + using GHistRowT = GHistRow; + + void Init(::sycl::queue qu, size_t nbins) { + qu_ = qu; + if (nbins != nbins_) { + hist_buffer_.Init(qu_, nbins); + nbins_ = nbins; + } + } + + void Reset(size_t nblocks) { + hist_device_buffer_.Resize(&qu_, nblocks * nbins_ * 2); + } + + GHistRowT& GetDeviceBuffer() { + return hist_device_buffer_; + } + + protected: + /*! \brief Number of bins in each histogram */ + size_t nbins_ = 0; + /*! \brief Buffers for histograms for all nodes processed */ + HistCollection hist_buffer_; + + /*! \brief Buffer for additional histograms for Parallel processing */ + GHistRowT hist_device_buffer_; + + ::sycl::queue qu_; +}; + +/*! + * \brief Builder for histograms of gradient statistics + */ +template +class GHistBuilder { + public: + template + using GHistRowT = GHistRow; + + GHistBuilder() = default; + GHistBuilder(::sycl::queue qu, uint32_t nbins) : qu_{qu}, nbins_{nbins} {} + + // Construct a histogram via histogram aggregation + ::sycl::event BuildHist(const USMVector& gpair_device, + const RowSetCollection::Elem& row_indices, + const GHistIndexMatrix& gmat, + GHistRowT* HistCollection, + bool isDense, + GHistRowT* hist_buffer, + ::sycl::event event, + bool force_atomic_use = false); + + // Construct a histogram via subtraction trick + void SubtractionTrick(GHistRowT* self, + const GHistRowT& sibling, + const GHistRowT& parent); + + uint32_t GetNumBins() const { + return nbins_; + } + + private: + /*! \brief Number of all bins over all features */ + uint32_t nbins_ { 0 }; + + ::sycl::queue qu_; +}; +} // namespace common +} // namespace sycl +} // namespace xgboost +#endif // PLUGIN_SYCL_COMMON_HIST_UTIL_H_ diff --git a/plugin/sycl/common/partition_builder.h b/plugin/sycl/common/partition_builder.h index 37d1af241ab1..c520ff31fb8e 100644 --- a/plugin/sycl/common/partition_builder.h +++ b/plugin/sycl/common/partition_builder.h @@ -21,6 +21,9 @@ #pragma GCC diagnostic pop #include "../data.h" +#include "row_set.h" +#include "../data/gradient_index.h" +#include "../tree/expand_entry.h" #include @@ -28,6 +31,87 @@ namespace xgboost { namespace sycl { namespace common { +// split row indexes (rid_span) to 2 parts (both stored in rid_buf) depending +// on comparison of indexes values (idx_span) and split point (split_cond) +// Handle dense columns +template +inline ::sycl::event PartitionDenseKernel( + ::sycl::queue* qu, + const GHistIndexMatrix& gmat, + const RowSetCollection::Elem& rid_span, + const size_t fid, + const int32_t split_cond, + xgboost::common::Span* rid_buf, + size_t* parts_size, + ::sycl::event event) { + const size_t row_stride = gmat.row_stride; + const BinIdxType* gradient_index = gmat.index.data(); + const size_t* rid = rid_span.begin; + const size_t range_size = rid_span.Size(); + const size_t offset = gmat.cut.Ptrs()[fid]; + + size_t* p_rid_buf = rid_buf->data(); + + return qu->submit([&](::sycl::handler& cgh) { + cgh.depends_on(event); + cgh.parallel_for<>(::sycl::range<1>(range_size), [=](::sycl::item<1> nid) { + const size_t id = rid[nid.get_id(0)]; + const int32_t value = static_cast(gradient_index[id * row_stride + fid] + offset); + const bool is_left = value <= split_cond; + if (is_left) { + AtomicRef n_left(parts_size[0]); + p_rid_buf[n_left.fetch_add(1)] = id; + } else { + AtomicRef n_right(parts_size[1]); + p_rid_buf[range_size - n_right.fetch_add(1) - 1] = id; + } + }); + }); +} + +// split row indexes (rid_span) to 2 parts (both stored in rid_buf) depending +// on comparison of indexes values (idx_span) and split point (split_cond) +// Handle sparce columns +template +inline ::sycl::event PartitionSparseKernel(::sycl::queue* qu, + const GHistIndexMatrix& gmat, + const RowSetCollection::Elem& rid_span, + const size_t fid, + const int32_t split_cond, + xgboost::common::Span* rid_buf, + size_t* parts_size, + ::sycl::event event) { + const size_t row_stride = gmat.row_stride; + const BinIdxType* gradient_index = gmat.index.data(); + const size_t* rid = rid_span.begin; + const size_t range_size = rid_span.Size(); + const uint32_t* cut_ptrs = gmat.cut_device.Ptrs().DataConst(); + + size_t* p_rid_buf = rid_buf->data(); + return qu->submit([&](::sycl::handler& cgh) { + cgh.depends_on(event); + cgh.parallel_for<>(::sycl::range<1>(range_size), [=](::sycl::item<1> nid) { + const size_t id = rid[nid.get_id(0)]; + + const BinIdxType* gr_index_local = gradient_index + row_stride * id; + const int32_t fid_local = std::lower_bound(gr_index_local, + gr_index_local + row_stride, + cut_ptrs[fid]) - gr_index_local; + const bool is_left = (fid_local >= row_stride || + gr_index_local[fid_local] >= cut_ptrs[fid + 1]) ? + default_left : + gr_index_local[fid_local] <= split_cond; + if (is_left) { + AtomicRef n_left(parts_size[0]); + p_rid_buf[n_left.fetch_add(1)] = id; + } else { + AtomicRef n_right(parts_size[1]); + p_rid_buf[range_size - n_right.fetch_add(1) - 1] = id; + } + }); + }); +} + // The builder is required for samples partition to left and rights children for set of nodes class PartitionBuilder { public: @@ -53,7 +137,6 @@ class PartitionBuilder { return result_rows_[2 * nid]; } - size_t GetNRightElems(int nid) const { return result_rows_[2 * nid + 1]; } @@ -72,19 +155,97 @@ class PartitionBuilder { return { data_.Data() + nodes_offsets_[nid], nodes_offsets_[nid + 1] - nodes_offsets_[nid] }; } + template + ::sycl::event Partition(const int32_t split_cond, + const GHistIndexMatrix& gmat, + const RowSetCollection::Elem& rid_span, + const xgboost::RegTree::Node& node, + xgboost::common::Span* rid_buf, + size_t* parts_size, + ::sycl::event event) { + const bst_uint fid = node.SplitIndex(); + const bool default_left = node.DefaultLeft(); + + if (gmat.IsDense()) { + if (default_left) { + return PartitionDenseKernel(qu_, gmat, rid_span, fid, + split_cond, rid_buf, parts_size, event); + } else { + return PartitionDenseKernel(qu_, gmat, rid_span, fid, + split_cond, rid_buf, parts_size, event); + } + } else { + if (default_left) { + return PartitionSparseKernel(qu_, gmat, rid_span, fid, + split_cond, rid_buf, parts_size, event); + } else { + return PartitionSparseKernel(qu_, gmat, rid_span, fid, + split_cond, rid_buf, parts_size, event); + } + } + } + + // Entry point for Partition + void Partition(const GHistIndexMatrix& gmat, + const std::vector nodes, + const RowSetCollection& row_set_collection, + const std::vector& split_conditions, + RegTree* p_tree, + ::sycl::event* general_event) { + nodes_events_.resize(n_nodes_); + + parts_size_.ResizeAndFill(qu_, 2 * n_nodes_, 0, general_event); + + for (size_t node_in_set = 0; node_in_set < n_nodes_; node_in_set++) { + const int32_t nid = nodes[node_in_set].nid; + ::sycl::event& node_event = nodes_events_[node_in_set]; + const auto& rid_span = row_set_collection[nid]; + if (rid_span.Size() > 0) { + const RegTree::Node& node = (*p_tree)[nid]; + xgboost::common::Span rid_buf = GetData(node_in_set); + size_t* part_size = parts_size_.Data() + 2 * node_in_set; + int32_t split_condition = split_conditions[node_in_set]; + switch (gmat.index.GetBinTypeSize()) { + case common::BinTypeSize::kUint8BinsTypeSize: + node_event = Partition(split_condition, gmat, rid_span, node, + &rid_buf, part_size, *general_event); + break; + case common::BinTypeSize::kUint16BinsTypeSize: + node_event = Partition(split_condition, gmat, rid_span, node, + &rid_buf, part_size, *general_event); + break; + case common::BinTypeSize::kUint32BinsTypeSize: + node_event = Partition(split_condition, gmat, rid_span, node, + &rid_buf, part_size, *general_event); + break; + default: + CHECK(false); // no default behavior + } + } else { + node_event = ::sycl::event(); + } + } + + *general_event = qu_->memcpy(result_rows_.data(), + parts_size_.DataConst(), + sizeof(size_t) * 2 * n_nodes_, + nodes_events_); + } + void MergeToArray(size_t nid, size_t* data_result, - ::sycl::event event) { + ::sycl::event* event) { size_t n_nodes_total = GetNLeftElems(nid) + GetNRightElems(nid); if (n_nodes_total > 0) { const size_t* data = data_.Data() + nodes_offsets_[nid]; - qu_->memcpy(data_result, data, sizeof(size_t) * n_nodes_total, event); + qu_->memcpy(data_result, data, sizeof(size_t) * n_nodes_total, *event); } } protected: std::vector nodes_offsets_; std::vector result_rows_; + std::vector<::sycl::event> nodes_events_; size_t n_nodes_; USMVector parts_size_; diff --git a/plugin/sycl/data.h b/plugin/sycl/data.h index 37d5842bf9a4..8f4bb2516f05 100644 --- a/plugin/sycl/data.h +++ b/plugin/sycl/data.h @@ -80,6 +80,12 @@ class USMVector { qu->fill(data_.get(), v, size_).wait(); } + USMVector(::sycl::queue* qu, size_t size, T v, + ::sycl::event* event) : size_(size), capacity_(size) { + data_ = allocate_memory_(qu, size_); + *event = qu->fill(data_.get(), v, size_, *event); + } + USMVector(::sycl::queue* qu, const std::vector &vec) { size_ = vec.size(); capacity_ = size_; @@ -171,20 +177,20 @@ class USMVector { } } - ::sycl::event ResizeAndFill(::sycl::queue* qu, size_t size_new, int v) { + void ResizeAndFill(::sycl::queue* qu, size_t size_new, int v, ::sycl::event* event) { if (size_new <= size_) { size_ = size_new; - return qu->memset(data_.get(), v, size_new * sizeof(T)); + *event = qu->memset(data_.get(), v, size_new * sizeof(T), *event); } else if (size_new <= capacity_) { size_ = size_new; - return qu->memset(data_.get(), v, size_new * sizeof(T)); + *event = qu->memset(data_.get(), v, size_new * sizeof(T), *event); } else { size_t size_old = size_; auto data_old = data_; size_ = size_new; capacity_ = size_new; data_ = allocate_memory_(qu, size_); - return qu->memset(data_.get(), v, size_new * sizeof(T)); + *event = qu->memset(data_.get(), v, size_new * sizeof(T), *event); } } @@ -211,11 +217,16 @@ class USMVector { struct DeviceMatrix { DMatrix* p_mat; // Pointer to the original matrix on the host ::sycl::queue qu_; - USMVector row_ptr; + USMVector row_ptr; USMVector data; size_t total_offset; - DeviceMatrix(::sycl::queue qu, DMatrix* dmat) : p_mat(dmat), qu_(qu) { + DeviceMatrix() = default; + + void Init(::sycl::queue qu, DMatrix* dmat) { + qu_ = qu; + p_mat = dmat; + size_t num_row = 0; size_t num_nonzero = 0; for (auto &batch : dmat->GetBatches()) { @@ -226,27 +237,41 @@ struct DeviceMatrix { } row_ptr.Resize(&qu_, num_row + 1); + size_t* rows = row_ptr.Data(); data.Resize(&qu_, num_nonzero); size_t data_offset = 0; + ::sycl::event event; for (auto &batch : dmat->GetBatches()) { const auto& data_vec = batch.data.HostVector(); const auto& offset_vec = batch.offset.HostVector(); size_t batch_size = batch.Size(); if (batch_size > 0) { - std::copy(offset_vec.data(), offset_vec.data() + batch_size, - row_ptr.Data() + batch.base_rowid); - if (batch.base_rowid > 0) { - for (size_t i = 0; i < batch_size; i++) - row_ptr[i + batch.base_rowid] += batch.base_rowid; + const auto base_rowid = batch.base_rowid; + event = qu.memcpy(row_ptr.Data() + base_rowid, offset_vec.data(), + sizeof(size_t) * batch_size, event); + if (base_rowid > 0) { + qu.submit([&](::sycl::handler& cgh) { + cgh.depends_on(event); + cgh.parallel_for<>(::sycl::range<1>(batch_size), [=](::sycl::id<1> pid) { + int row_id = pid[0]; + rows[row_id] += base_rowid; + }); + }); } - qu.memcpy(data.Data() + data_offset, - data_vec.data(), - offset_vec[batch_size] * sizeof(Entry)).wait(); + event = qu.memcpy(data.Data() + data_offset, data_vec.data(), + sizeof(Entry) * offset_vec[batch_size], event); data_offset += offset_vec[batch_size]; + qu.wait(); } } - row_ptr[num_row] = data_offset; + qu.submit([&](::sycl::handler& cgh) { + cgh.depends_on(event); + cgh.single_task<>([=] { + rows[num_row] = data_offset; + }); + }); + qu.wait(); total_offset = data_offset; } diff --git a/plugin/sycl/data/gradient_index.cc b/plugin/sycl/data/gradient_index.cc index 49b66a71052f..e193b66894c9 100644 --- a/plugin/sycl/data/gradient_index.cc +++ b/plugin/sycl/data/gradient_index.cc @@ -57,7 +57,7 @@ void GHistIndexMatrix::SetIndexData(::sycl::queue qu, uint32_t* offsets) { if (nbins == 0) return; const xgboost::Entry *data_ptr = dmat.data.DataConst(); - const bst_row_t *offset_vec = dmat.row_ptr.DataConst(); + const bst_idx_t *offset_vec = dmat.row_ptr.DataConst(); const size_t num_rows = dmat.row_ptr.Size() - 1; const bst_float* cut_values = cut_device.Values().DataConst(); const uint32_t* cut_ptrs = cut_device.Ptrs().DataConst(); diff --git a/plugin/sycl/device_manager.cc b/plugin/sycl/device_manager.cc index 0254cdd6a396..0ddbf144083b 100644 --- a/plugin/sycl/device_manager.cc +++ b/plugin/sycl/device_manager.cc @@ -2,14 +2,10 @@ * Copyright 2017-2023 by Contributors * \file device_manager.cc */ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wtautological-constant-compare" -#pragma GCC diagnostic ignored "-W#pragma-messages" -#include -#pragma GCC diagnostic pop - #include "../sycl/device_manager.h" +#include "../../src/collective/communicator-inl.h" + namespace xgboost { namespace sycl { @@ -21,22 +17,23 @@ ::sycl::device DeviceManager::GetDevice(const DeviceOrd& device_spec) const { } bool not_use_default_selector = (device_spec.ordinal != kDefaultOrdinal) || - (rabit::IsDistributed()); + (collective::IsDistributed()); if (not_use_default_selector) { DeviceRegister& device_register = GetDevicesRegister(); - const int device_idx = rabit::IsDistributed() ? rabit::GetRank() : device_spec.ordinal; + const int device_idx = + collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal; if (device_spec.IsSyclDefault()) { - auto& devices = device_register.devices; - CHECK_LT(device_idx, devices.size()); - return devices[device_idx]; + auto& devices = device_register.devices; + CHECK_LT(device_idx, devices.size()); + return devices[device_idx]; } else if (device_spec.IsSyclCPU()) { - auto& cpu_devices = device_register.cpu_devices; - CHECK_LT(device_idx, cpu_devices.size()); - return cpu_devices[device_idx]; + auto& cpu_devices = device_register.cpu_devices; + CHECK_LT(device_idx, cpu_devices.size()); + return cpu_devices[device_idx]; } else { - auto& gpu_devices = device_register.gpu_devices; - CHECK_LT(device_idx, gpu_devices.size()); - return gpu_devices[device_idx]; + auto& gpu_devices = device_register.gpu_devices; + CHECK_LT(device_idx, gpu_devices.size()); + return gpu_devices[device_idx]; } } else { if (device_spec.IsSyclCPU()) { @@ -62,24 +59,25 @@ ::sycl::queue DeviceManager::GetQueue(const DeviceOrd& device_spec) const { } bool not_use_default_selector = (device_spec.ordinal != kDefaultOrdinal) || - (rabit::IsDistributed()); + (collective::IsDistributed()); std::lock_guard guard(queue_registering_mutex); if (not_use_default_selector) { - DeviceRegister& device_register = GetDevicesRegister(); - const int device_idx = rabit::IsDistributed() ? rabit::GetRank() : device_spec.ordinal; - if (device_spec.IsSyclDefault()) { - auto& devices = device_register.devices; - CHECK_LT(device_idx, devices.size()); - queue_register[device_spec.Name()] = ::sycl::queue(devices[device_idx]); - } else if (device_spec.IsSyclCPU()) { - auto& cpu_devices = device_register.cpu_devices; - CHECK_LT(device_idx, cpu_devices.size()); - queue_register[device_spec.Name()] = ::sycl::queue(cpu_devices[device_idx]);; - } else if (device_spec.IsSyclGPU()) { - auto& gpu_devices = device_register.gpu_devices; - CHECK_LT(device_idx, gpu_devices.size()); - queue_register[device_spec.Name()] = ::sycl::queue(gpu_devices[device_idx]); - } + DeviceRegister& device_register = GetDevicesRegister(); + const int device_idx = + collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal; + if (device_spec.IsSyclDefault()) { + auto& devices = device_register.devices; + CHECK_LT(device_idx, devices.size()); + queue_register[device_spec.Name()] = ::sycl::queue(devices[device_idx]); + } else if (device_spec.IsSyclCPU()) { + auto& cpu_devices = device_register.cpu_devices; + CHECK_LT(device_idx, cpu_devices.size()); + queue_register[device_spec.Name()] = ::sycl::queue(cpu_devices[device_idx]); + } else if (device_spec.IsSyclGPU()) { + auto& gpu_devices = device_register.gpu_devices; + CHECK_LT(device_idx, gpu_devices.size()); + queue_register[device_spec.Name()] = ::sycl::queue(gpu_devices[device_idx]); + } } else { if (device_spec.IsSyclCPU()) { queue_register[device_spec.Name()] = ::sycl::queue(::sycl::cpu_selector_v); diff --git a/plugin/sycl/device_manager.h b/plugin/sycl/device_manager.h index 0ae2ee9fed61..84d4b24c0aa8 100644 --- a/plugin/sycl/device_manager.h +++ b/plugin/sycl/device_manager.h @@ -12,7 +12,11 @@ #include +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#pragma GCC diagnostic ignored "-W#pragma-messages" #include "xgboost/context.h" +#pragma GCC diagnostic pop namespace xgboost { namespace sycl { diff --git a/plugin/sycl/objective/multiclass_obj.cc b/plugin/sycl/objective/multiclass_obj.cc index 3104dd35e716..5dcc8c3de599 100644 --- a/plugin/sycl/objective/multiclass_obj.cc +++ b/plugin/sycl/objective/multiclass_obj.cc @@ -3,20 +3,15 @@ * \file multiclass_obj.cc * \brief Definition of multi-class classification objectives. */ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wtautological-constant-compare" -#pragma GCC diagnostic ignored "-W#pragma-messages" -#include -#pragma GCC diagnostic pop - #include #include #include #include -#include "xgboost/parameter.h" #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#pragma GCC diagnostic ignored "-W#pragma-messages" +#include "xgboost/parameter.h" #include "xgboost/data.h" #include "../../src/common/math.h" #pragma GCC diagnostic pop diff --git a/plugin/sycl/objective/regression_obj.cc b/plugin/sycl/objective/regression_obj.cc index 985498717aaf..82467a7c4848 100644 --- a/plugin/sycl/objective/regression_obj.cc +++ b/plugin/sycl/objective/regression_obj.cc @@ -9,7 +9,6 @@ #include #include #pragma GCC diagnostic pop -#include #include #include diff --git a/plugin/sycl/predictor/predictor.cc b/plugin/sycl/predictor/predictor.cc index dd56dd3bd462..c941bca102e7 100755 --- a/plugin/sycl/predictor/predictor.cc +++ b/plugin/sycl/predictor/predictor.cc @@ -4,7 +4,6 @@ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wtautological-constant-compare" #pragma GCC diagnostic ignored "-W#pragma-messages" -#include #pragma GCC diagnostic pop #include @@ -280,7 +279,8 @@ class Predictor : public xgboost::Predictor { uint32_t tree_end = 0) const override { ::sycl::queue qu = device_manager.GetQueue(ctx_->Device()); // TODO(razdoburdin): remove temporary workaround after cache fix - sycl::DeviceMatrix device_matrix(qu, dmat); + sycl::DeviceMatrix device_matrix; + device_matrix.Init(qu, dmat); auto* out_preds = &predts->predictions; if (tree_end == 0) { diff --git a/plugin/sycl/tree/expand_entry.h b/plugin/sycl/tree/expand_entry.h new file mode 100644 index 000000000000..2520ff95db5a --- /dev/null +++ b/plugin/sycl/tree/expand_entry.h @@ -0,0 +1,50 @@ +/*! + * Copyright 2017-2024 by Contributors + * \file expand_entry.h + */ +#ifndef PLUGIN_SYCL_TREE_EXPAND_ENTRY_H_ +#define PLUGIN_SYCL_TREE_EXPAND_ENTRY_H_ + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#include "../../src/tree/constraints.h" +#pragma GCC diagnostic pop +#include "../../src/tree/hist/expand_entry.h" + +namespace xgboost { +namespace sycl { +namespace tree { +/* tree growing policies */ +struct ExpandEntry : public xgboost::tree::ExpandEntryImpl { + static constexpr bst_node_t kRootNid = 0; + + xgboost::tree::SplitEntry split; + + ExpandEntry(int nid, int depth) : ExpandEntryImpl{nid, depth} {} + + inline bst_node_t GetSiblingId(const xgboost::RegTree* p_tree) const { + CHECK_EQ((*p_tree)[nid].IsRoot(), false); + const size_t parent_id = (*p_tree)[nid].Parent(); + return GetSiblingId(p_tree, parent_id); + } + + inline bst_node_t GetSiblingId(const xgboost::RegTree* p_tree, size_t parent_id) const { + return p_tree->IsLeftChild(nid) ? p_tree->RightChild(parent_id) + : p_tree->LeftChild(parent_id); + } + + bool IsValidImpl(xgboost::tree::TrainParam const ¶m, int32_t num_leaves) const { + if (split.loss_chg <= kRtEps) return false; + if (split.loss_chg < param.min_split_loss) return false; + if (param.max_depth > 0 && depth == param.max_depth) return false; + if (param.max_leaves > 0 && num_leaves == param.max_leaves) return false; + + return true; + } +}; + +} // namespace tree +} // namespace sycl +} // namespace xgboost + +#endif // PLUGIN_SYCL_TREE_EXPAND_ENTRY_H_ diff --git a/plugin/sycl/tree/hist_row_adder.h b/plugin/sycl/tree/hist_row_adder.h new file mode 100644 index 000000000000..968bcca737dc --- /dev/null +++ b/plugin/sycl/tree/hist_row_adder.h @@ -0,0 +1,46 @@ +/*! + * Copyright 2017-2024 by Contributors + * \file hist_row_adder.h + */ +#ifndef PLUGIN_SYCL_TREE_HIST_ROW_ADDER_H_ +#define PLUGIN_SYCL_TREE_HIST_ROW_ADDER_H_ + +#include +#include + +namespace xgboost { +namespace sycl { +namespace tree { + +template +class HistRowsAdder { + public: + virtual void AddHistRows(HistUpdater* builder, + std::vector* sync_ids, RegTree *p_tree) = 0; + virtual ~HistRowsAdder() = default; +}; + +template +class BatchHistRowsAdder: public HistRowsAdder { + public: + void AddHistRows(HistUpdater* builder, + std::vector* sync_ids, RegTree *p_tree) override { + builder->builder_monitor_.Start("AddHistRows"); + + for (auto const& entry : builder->nodes_for_explicit_hist_build_) { + int nid = entry.nid; + auto event = builder->hist_.AddHistRow(nid); + } + for (auto const& node : builder->nodes_for_subtraction_trick_) { + auto event = builder->hist_.AddHistRow(node.nid); + } + + builder->builder_monitor_.Stop("AddHistRows"); + } +}; + +} // namespace tree +} // namespace sycl +} // namespace xgboost + +#endif // PLUGIN_SYCL_TREE_HIST_ROW_ADDER_H_ diff --git a/plugin/sycl/tree/hist_synchronizer.h b/plugin/sycl/tree/hist_synchronizer.h new file mode 100644 index 000000000000..2275a51dba37 --- /dev/null +++ b/plugin/sycl/tree/hist_synchronizer.h @@ -0,0 +1,68 @@ +/*! + * Copyright 2017-2024 by Contributors + * \file hist_synchronizer.h + */ +#ifndef PLUGIN_SYCL_TREE_HIST_SYNCHRONIZER_H_ +#define PLUGIN_SYCL_TREE_HIST_SYNCHRONIZER_H_ + +#include + +#include "../common/hist_util.h" +#include "expand_entry.h" + +namespace xgboost { +namespace sycl { +namespace tree { + +template +class HistUpdater; + +template +class HistSynchronizer { + public: + virtual void SyncHistograms(HistUpdater* builder, + const std::vector& sync_ids, + RegTree *p_tree) = 0; + virtual ~HistSynchronizer() = default; +}; + +template +class BatchHistSynchronizer: public HistSynchronizer { + public: + void SyncHistograms(HistUpdater* builder, + const std::vector& sync_ids, + RegTree *p_tree) override { + builder->builder_monitor_.Start("SyncHistograms"); + const size_t nbins = builder->hist_builder_.GetNumBins(); + + hist_sync_events_.resize(builder->nodes_for_explicit_hist_build_.size()); + for (int i = 0; i < builder->nodes_for_explicit_hist_build_.size(); i++) { + const auto entry = builder->nodes_for_explicit_hist_build_[i]; + auto& this_hist = builder->hist_[entry.nid]; + + if (!(*p_tree)[entry.nid].IsRoot()) { + const size_t parent_id = (*p_tree)[entry.nid].Parent(); + auto& parent_hist = builder->hist_[parent_id]; + auto& sibling_hist = builder->hist_[entry.GetSiblingId(p_tree, parent_id)]; + hist_sync_events_[i] = common::SubtractionHist(builder->qu_, &sibling_hist, parent_hist, + this_hist, nbins, ::sycl::event()); + } + } + builder->qu_.wait_and_throw(); + + builder->builder_monitor_.Stop("SyncHistograms"); + } + + std::vector<::sycl::event> GetEvents() const { + return hist_sync_events_; + } + + private: + std::vector<::sycl::event> hist_sync_events_; +}; + +} // namespace tree +} // namespace sycl +} // namespace xgboost + +#endif // PLUGIN_SYCL_TREE_HIST_SYNCHRONIZER_H_ diff --git a/plugin/sycl/tree/hist_updater.cc b/plugin/sycl/tree/hist_updater.cc new file mode 100644 index 000000000000..76ecdeab8ac3 --- /dev/null +++ b/plugin/sycl/tree/hist_updater.cc @@ -0,0 +1,320 @@ +/*! + * Copyright 2017-2024 by Contributors + * \file hist_updater.cc + */ + +#include "hist_updater.h" + +#include + +#include "../common/hist_util.h" +#include "../../src/collective/allreduce.h" + +namespace xgboost { +namespace sycl { +namespace tree { + +template +void HistUpdater::SetHistSynchronizer( + HistSynchronizer *sync) { + hist_synchronizer_.reset(sync); +} + +template +void HistUpdater::SetHistRowsAdder( + HistRowsAdder *adder) { + hist_rows_adder_.reset(adder); +} + +template +void HistUpdater::BuildHistogramsLossGuide( + ExpandEntry entry, + const common::GHistIndexMatrix &gmat, + RegTree *p_tree, + const USMVector &gpair_device) { + nodes_for_explicit_hist_build_.clear(); + nodes_for_subtraction_trick_.clear(); + nodes_for_explicit_hist_build_.push_back(entry); + + if (!(*p_tree)[entry.nid].IsRoot()) { + auto sibling_id = entry.GetSiblingId(p_tree); + nodes_for_subtraction_trick_.emplace_back(sibling_id, p_tree->GetDepth(sibling_id)); + } + + std::vector sync_ids; + hist_rows_adder_->AddHistRows(this, &sync_ids, p_tree); + qu_.wait_and_throw(); + BuildLocalHistograms(gmat, p_tree, gpair_device); + hist_synchronizer_->SyncHistograms(this, sync_ids, p_tree); +} + +template +void HistUpdater::BuildLocalHistograms( + const common::GHistIndexMatrix &gmat, + RegTree *p_tree, + const USMVector &gpair_device) { + builder_monitor_.Start("BuildLocalHistograms"); + const size_t n_nodes = nodes_for_explicit_hist_build_.size(); + ::sycl::event event; + + for (size_t i = 0; i < n_nodes; i++) { + const int32_t nid = nodes_for_explicit_hist_build_[i].nid; + + if (row_set_collection_[nid].Size() > 0) { + event = BuildHist(gpair_device, row_set_collection_[nid], gmat, &(hist_[nid]), + &(hist_buffer_.GetDeviceBuffer()), event); + } else { + common::InitHist(qu_, &(hist_[nid]), hist_[nid].Size(), &event); + } + } + qu_.wait_and_throw(); + builder_monitor_.Stop("BuildLocalHistograms"); +} + +template +void HistUpdater::InitSampling( + const USMVector &gpair, + USMVector* row_indices) { + const size_t num_rows = row_indices->Size(); + auto* row_idx = row_indices->Data(); + const auto* gpair_ptr = gpair.DataConst(); + uint64_t num_samples = 0; + const auto subsample = param_.subsample; + ::sycl::event event; + + { + ::sycl::buffer flag_buf(&num_samples, 1); + uint64_t seed = seed_; + seed_ += num_rows; + event = qu_.submit([&](::sycl::handler& cgh) { + auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh); + cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)), + [=](::sycl::item<1> pid) { + uint64_t i = pid.get_id(0); + + // Create minstd_rand engine + oneapi::dpl::minstd_rand engine(seed, i); + oneapi::dpl::bernoulli_distribution coin_flip(subsample); + + auto rnd = coin_flip(engine); + if (gpair_ptr[i].GetHess() >= 0.0f && rnd) { + AtomicRef num_samples_ref(flag_buf_acc[0]); + row_idx[num_samples_ref++] = i; + } + }); + }); + /* After calling a destructor for flag_buf, content will be copyed to num_samples */ + } + + row_indices->Resize(&qu_, num_samples, 0, &event); + qu_.wait(); +} + +template +void HistUpdater::InitData( + const common::GHistIndexMatrix& gmat, + const USMVector &gpair, + const DMatrix& fmat, + const RegTree& tree) { + CHECK((param_.max_depth > 0 || param_.max_leaves > 0)) + << "max_depth or max_leaves cannot be both 0 (unlimited); " + << "at least one should be a positive quantity."; + if (param_.grow_policy == xgboost::tree::TrainParam::kDepthWise) { + CHECK(param_.max_depth > 0) << "max_depth cannot be 0 (unlimited) " + << "when grow_policy is depthwise."; + } + builder_monitor_.Start("InitData"); + const auto& info = fmat.Info(); + + // initialize the row set + { + row_set_collection_.Clear(); + + // initialize histogram collection + uint32_t nbins = gmat.cut.Ptrs().back(); + hist_.Init(qu_, nbins); + + hist_buffer_.Init(qu_, nbins); + size_t buffer_size = kBufferSize; + if (buffer_size > info.num_row_ / kMinBlockSize + 1) { + buffer_size = info.num_row_ / kMinBlockSize + 1; + } + hist_buffer_.Reset(buffer_size); + + // initialize histogram builder + hist_builder_ = common::GHistBuilder(qu_, nbins); + + USMVector* row_indices = &(row_set_collection_.Data()); + row_indices->Resize(&qu_, info.num_row_); + size_t* p_row_indices = row_indices->Data(); + // mark subsample and build list of member rows + if (param_.subsample < 1.0f) { + CHECK_EQ(param_.sampling_method, xgboost::tree::TrainParam::kUniform) + << "Only uniform sampling is supported, " + << "gradient-based sampling is only support by GPU Hist."; + InitSampling(gpair, row_indices); + } else { + int has_neg_hess = 0; + const GradientPair* gpair_ptr = gpair.DataConst(); + ::sycl::event event; + { + ::sycl::buffer flag_buf(&has_neg_hess, 1); + event = qu_.submit([&](::sycl::handler& cgh) { + auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh); + cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(info.num_row_)), + [=](::sycl::item<1> pid) { + const size_t idx = pid.get_id(0); + p_row_indices[idx] = idx; + if (gpair_ptr[idx].GetHess() < 0.0f) { + AtomicRef has_neg_hess_ref(flag_buf_acc[0]); + has_neg_hess_ref.fetch_max(1); + } + }); + }); + } + + if (has_neg_hess) { + size_t max_idx = 0; + { + ::sycl::buffer flag_buf(&max_idx, 1); + event = qu_.submit([&](::sycl::handler& cgh) { + cgh.depends_on(event); + auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh); + cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(info.num_row_)), + [=](::sycl::item<1> pid) { + const size_t idx = pid.get_id(0); + if (gpair_ptr[idx].GetHess() >= 0.0f) { + AtomicRef max_idx_ref(flag_buf_acc[0]); + p_row_indices[max_idx_ref++] = idx; + } + }); + }); + } + row_indices->Resize(&qu_, max_idx, 0, &event); + } + qu_.wait_and_throw(); + } + } + row_set_collection_.Init(); + + { + /* determine layout of data */ + const size_t nrow = info.num_row_; + const size_t ncol = info.num_col_; + const size_t nnz = info.num_nonzero_; + // number of discrete bins for feature 0 + const uint32_t nbins_f0 = gmat.cut.Ptrs()[1] - gmat.cut.Ptrs()[0]; + if (nrow * ncol == nnz) { + // dense data with zero-based indexing + data_layout_ = kDenseDataZeroBased; + } else if (nbins_f0 == 0 && nrow * (ncol - 1) == nnz) { + // dense data with one-based indexing + data_layout_ = kDenseDataOneBased; + } else { + // sparse data + data_layout_ = kSparseData; + } + } + + if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) { + /* specialized code for dense data: + choose the column that has a least positive number of discrete bins. + For dense data (with no missing value), + the sum of gradient histogram is equal to snode[nid] */ + const std::vector& row_ptr = gmat.cut.Ptrs(); + const auto nfeature = static_cast(row_ptr.size() - 1); + uint32_t min_nbins_per_feature = 0; + for (bst_uint i = 0; i < nfeature; ++i) { + const uint32_t nbins = row_ptr[i + 1] - row_ptr[i]; + if (nbins > 0) { + if (min_nbins_per_feature == 0 || min_nbins_per_feature > nbins) { + min_nbins_per_feature = nbins; + fid_least_bins_ = i; + } + } + } + CHECK_GT(min_nbins_per_feature, 0U); + } + + std::fill(snode_host_.begin(), snode_host_.end(), NodeEntry(param_)); + builder_monitor_.Stop("InitData"); +} + +template +void HistUpdater::InitNewNode(int nid, + const common::GHistIndexMatrix& gmat, + const USMVector &gpair, + const DMatrix& fmat, + const RegTree& tree) { + builder_monitor_.Start("InitNewNode"); + + snode_host_.resize(tree.NumNodes(), NodeEntry(param_)); + { + if (tree[nid].IsRoot()) { + GradStats grad_stat; + if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) { + const std::vector& row_ptr = gmat.cut.Ptrs(); + const uint32_t ibegin = row_ptr[fid_least_bins_]; + const uint32_t iend = row_ptr[fid_least_bins_ + 1]; + const auto* hist = reinterpret_cast*>(hist_[nid].Data()); + + std::vector> ets(iend - ibegin); + qu_.memcpy(ets.data(), hist + ibegin, + (iend - ibegin) * sizeof(GradStats)).wait_and_throw(); + for (const auto& et : ets) { + grad_stat += et; + } + } else { + const common::RowSetCollection::Elem e = row_set_collection_[nid]; + const size_t* row_idxs = e.begin; + const size_t size = e.Size(); + const GradientPair* gpair_ptr = gpair.DataConst(); + + ::sycl::buffer> buff(&grad_stat, 1); + qu_.submit([&](::sycl::handler& cgh) { + auto reduction = ::sycl::reduction(buff, cgh, ::sycl::plus<>()); + cgh.parallel_for<>(::sycl::range<1>(size), reduction, + [=](::sycl::item<1> pid, auto& sum) { + size_t i = pid.get_id(0); + size_t row_idx = row_idxs[i]; + if constexpr (std::is_same::value) { + sum += gpair_ptr[row_idx]; + } else { + sum += GradStats(gpair_ptr[row_idx].GetGrad(), + gpair_ptr[row_idx].GetHess()); + } + }); + }).wait_and_throw(); + } + auto rc = collective::Allreduce( + ctx_, linalg::MakeVec(reinterpret_cast(&grad_stat), 2), + collective::Op::kSum); + SafeColl(rc); + snode_host_[nid].stats = grad_stat; + } else { + int parent_id = tree[nid].Parent(); + if (tree[nid].IsLeftChild()) { + snode_host_[nid].stats = snode_host_[parent_id].best.left_sum; + } else { + snode_host_[nid].stats = snode_host_[parent_id].best.right_sum; + } + } + } + + // calculating the weights + { + auto evaluator = tree_evaluator_.GetEvaluator(); + bst_uint parentid = tree[nid].Parent(); + snode_host_[nid].weight = evaluator.CalcWeight(parentid, snode_host_[nid].stats); + snode_host_[nid].root_gain = evaluator.CalcGain(parentid, snode_host_[nid].stats); + } + builder_monitor_.Stop("InitNewNode"); +} + +template class HistUpdater; +template class HistUpdater; + +} // namespace tree +} // namespace sycl +} // namespace xgboost diff --git a/plugin/sycl/tree/hist_updater.h b/plugin/sycl/tree/hist_updater.h new file mode 100644 index 000000000000..544a7c26698a --- /dev/null +++ b/plugin/sycl/tree/hist_updater.h @@ -0,0 +1,167 @@ +/*! + * Copyright 2017-2024 by Contributors + * \file hist_updater.h + */ +#ifndef PLUGIN_SYCL_TREE_HIST_UPDATER_H_ +#define PLUGIN_SYCL_TREE_HIST_UPDATER_H_ + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#pragma GCC diagnostic ignored "-W#pragma-messages" +#include +#pragma GCC diagnostic pop + +#include +#include +#include + +#include "../common/partition_builder.h" +#include "split_evaluator.h" +#include "hist_synchronizer.h" +#include "hist_row_adder.h" + +#include "../data.h" + +namespace xgboost { +namespace sycl { +namespace tree { + +// data structure +template +struct NodeEntry { + /*! \brief statics for node entry */ + GradStats stats; + /*! \brief loss of this node, without split */ + GradType root_gain; + /*! \brief weight calculated related to current data */ + GradType weight; + /*! \brief current best solution */ + SplitEntry best; + // constructor + explicit NodeEntry(const xgboost::tree::TrainParam& param) + : root_gain(0.0f), weight(0.0f) {} +}; + +template +class HistUpdater { + public: + template + using GHistRowT = common::GHistRow; + using GradientPairT = xgboost::detail::GradientPairInternal; + + explicit HistUpdater(const Context* ctx, + ::sycl::queue qu, + const xgboost::tree::TrainParam& param, + std::unique_ptr pruner, + FeatureInteractionConstraintHost int_constraints_, + DMatrix const* fmat) + : ctx_(ctx), qu_(qu), param_(param), + tree_evaluator_(qu, param, fmat->Info().num_col_), + pruner_(std::move(pruner)), + interaction_constraints_{std::move(int_constraints_)}, + p_last_tree_(nullptr), p_last_fmat_(fmat) { + builder_monitor_.Init("SYCL::Quantile::HistUpdater"); + kernel_monitor_.Init("SYCL::Quantile::HistUpdater"); + const auto sub_group_sizes = + qu_.get_device().get_info<::sycl::info::device::sub_group_sizes>(); + sub_group_size_ = sub_group_sizes.back(); + } + + void SetHistSynchronizer(HistSynchronizer* sync); + void SetHistRowsAdder(HistRowsAdder* adder); + + protected: + friend class BatchHistSynchronizer; + friend class BatchHistRowsAdder; + + void InitSampling(const USMVector &gpair, + USMVector* row_indices); + + + void InitData(const common::GHistIndexMatrix& gmat, + const USMVector &gpair, + const DMatrix& fmat, + const RegTree& tree); + + inline ::sycl::event BuildHist( + const USMVector& gpair_device, + const common::RowSetCollection::Elem row_indices, + const common::GHistIndexMatrix& gmat, + GHistRowT* hist, + GHistRowT* hist_buffer, + ::sycl::event event_priv) { + return hist_builder_.BuildHist(gpair_device, row_indices, gmat, hist, + data_layout_ != kSparseData, hist_buffer, event_priv); + } + + void InitNewNode(int nid, + const common::GHistIndexMatrix& gmat, + const USMVector &gpair, + const DMatrix& fmat, + const RegTree& tree); + + void BuildLocalHistograms(const common::GHistIndexMatrix &gmat, + RegTree *p_tree, + const USMVector &gpair); + + void BuildHistogramsLossGuide( + ExpandEntry entry, + const common::GHistIndexMatrix &gmat, + RegTree *p_tree, + const USMVector &gpair); + + // --data fields-- + const Context* ctx_; + size_t sub_group_size_; + + // the internal row sets + common::RowSetCollection row_set_collection_; + + const xgboost::tree::TrainParam& param_; + TreeEvaluator tree_evaluator_; + std::unique_ptr pruner_; + FeatureInteractionConstraintHost interaction_constraints_; + + // back pointers to tree and data matrix + const RegTree* p_last_tree_; + DMatrix const* const p_last_fmat_; + + enum DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData }; + DataLayout data_layout_; + + constexpr static size_t kBufferSize = 2048; + constexpr static size_t kMinBlockSize = 128; + common::GHistBuilder hist_builder_; + common::ParallelGHistBuilder hist_buffer_; + /*! \brief culmulative histogram of gradients. */ + common::HistCollection hist_; + + /*! \brief TreeNode Data: statistics for each constructed node */ + std::vector> snode_host_; + + xgboost::common::Monitor builder_monitor_; + xgboost::common::Monitor kernel_monitor_; + + /*! \brief feature with least # of bins. to be used for dense specialization + of InitNewNode() */ + uint32_t fid_least_bins_; + + uint64_t seed_ = 0; + + // key is the node id which should be calculated by Subtraction Trick, value is the node which + // provides the evidence for substracts + std::vector nodes_for_subtraction_trick_; + // list of nodes whose histograms would be built explicitly. + std::vector nodes_for_explicit_hist_build_; + + std::unique_ptr> hist_synchronizer_; + std::unique_ptr> hist_rows_adder_; + + ::sycl::queue qu_; +}; + +} // namespace tree +} // namespace sycl +} // namespace xgboost + +#endif // PLUGIN_SYCL_TREE_HIST_UPDATER_H_ diff --git a/plugin/sycl/tree/param.h b/plugin/sycl/tree/param.h new file mode 100644 index 000000000000..a83a7ad138ab --- /dev/null +++ b/plugin/sycl/tree/param.h @@ -0,0 +1,164 @@ +/*! + * Copyright 2014-2024 by Contributors + */ +#ifndef PLUGIN_SYCL_TREE_PARAM_H_ +#define PLUGIN_SYCL_TREE_PARAM_H_ + + +#include +#include +#include +#include +#include + + +#include "xgboost/parameter.h" +#include "xgboost/data.h" +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#include "../src/tree/param.h" +#pragma GCC diagnostic pop + +#include + +namespace xgboost { +namespace sycl { +namespace tree { + + +/*! \brief Wrapper for necessary training parameters for regression tree to access on device */ +/* The original structure xgboost::tree::TrainParam can't be used, + * since std::vector are not copyable on sycl-devices. + */ +struct TrainParam { + float min_child_weight; + float reg_lambda; + float reg_alpha; + float max_delta_step; + + TrainParam() {} + + explicit TrainParam(const xgboost::tree::TrainParam& param) { + reg_lambda = param.reg_lambda; + reg_alpha = param.reg_alpha; + min_child_weight = param.min_child_weight; + max_delta_step = param.max_delta_step; + } +}; + +template +using GradStats = xgboost::detail::GradientPairInternal; + +/*! + * \brief SYCL implementation of SplitEntryContainer for device compilation. + * Original structure cannot be used due 'cat_bits' field of type std::vector, + * which is not device-copyable + */ +template +struct SplitEntryContainer { + /*! \brief loss change after split this node */ + bst_float loss_chg {0.0f}; + /*! \brief split index */ + bst_feature_t sindex{0}; + bst_float split_value{0.0f}; + + + GradientT left_sum; + GradientT right_sum; + + + SplitEntryContainer() = default; + + + friend std::ostream& operator<<(std::ostream& os, SplitEntryContainer const& s) { + os << "loss_chg: " << s.loss_chg << ", " + << "split index: " << s.SplitIndex() << ", " + << "split value: " << s.split_value << ", " + << "left_sum: " << s.left_sum << ", " + << "right_sum: " << s.right_sum; + return os; + } + /*!\return feature index to split on */ + bst_feature_t SplitIndex() const { return sindex & ((1U << 31) - 1U); } + /*!\return whether missing value goes to left branch */ + bool DefaultLeft() const { return (sindex >> 31) != 0; } + /*! + * \brief decides whether we can replace current entry with the given statistics + * + * This function gives better priority to lower index when loss_chg == new_loss_chg. + * Not the best way, but helps to give consistent result during multi-thread + * execution. + * + * \param new_loss_chg the loss reduction get through the split + * \param split_index the feature index where the split is on + */ + inline bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const { + if (::sycl::isinf(new_loss_chg)) { // in some cases new_loss_chg can be NaN or Inf, + // for example when lambda = 0 & min_child_weight = 0 + // skip value in this case + return false; + } else if (this->SplitIndex() <= split_index) { + return new_loss_chg > this->loss_chg; + } else { + return !(this->loss_chg > new_loss_chg); + } + } + /*! + * \brief update the split entry, replace it if e is better + * \param e candidate split solution + * \return whether the proposed split is better and can replace current split + */ + inline bool Update(const SplitEntryContainer &e) { + if (this->NeedReplace(e.loss_chg, e.SplitIndex())) { + this->loss_chg = e.loss_chg; + this->sindex = e.sindex; + this->split_value = e.split_value; + this->left_sum = e.left_sum; + this->right_sum = e.right_sum; + return true; + } else { + return false; + } + } + /*! + * \brief update the split entry, replace it if e is better + * \param new_loss_chg loss reduction of new candidate + * \param split_index feature index to split on + * \param new_split_value the split point + * \param default_left whether the missing value goes to left + * \return whether the proposed split is better and can replace current split + */ + bool Update(bst_float new_loss_chg, unsigned split_index, + bst_float new_split_value, bool default_left, + const GradientT &left_sum, + const GradientT &right_sum) { + if (this->NeedReplace(new_loss_chg, split_index)) { + this->loss_chg = new_loss_chg; + if (default_left) { + split_index |= (1U << 31); + } + this->sindex = split_index; + this->split_value = new_split_value; + this->left_sum = left_sum; + this->right_sum = right_sum; + return true; + } else { + return false; + } + } + + + /*! \brief same as update, used by AllReduce*/ + inline static void Reduce(SplitEntryContainer &dst, // NOLINT(*) + const SplitEntryContainer &src) { // NOLINT(*) + dst.Update(src); + } +}; + +template +using SplitEntry = SplitEntryContainer>; + +} // namespace tree +} // namespace sycl +} // namespace xgboost +#endif // PLUGIN_SYCL_TREE_PARAM_H_ diff --git a/plugin/sycl/tree/split_evaluator.h b/plugin/sycl/tree/split_evaluator.h new file mode 100644 index 000000000000..2f1e8c7c4e66 --- /dev/null +++ b/plugin/sycl/tree/split_evaluator.h @@ -0,0 +1,208 @@ +/*! + * Copyright 2018-2024 by Contributors + */ + +#ifndef PLUGIN_SYCL_TREE_SPLIT_EVALUATOR_H_ +#define PLUGIN_SYCL_TREE_SPLIT_EVALUATOR_H_ + +#include +#include +#include +#include +#include + +#include "param.h" +#include "../data.h" + +#include "xgboost/tree_model.h" +#include "xgboost/host_device_vector.h" +#include "xgboost/context.h" +#include "../../src/common/transform.h" +#include "../../src/common/math.h" +#include "../../src/tree/param.h" + +#include + +namespace xgboost { +namespace sycl { +namespace tree { + +/*! \brief SYCL implementation of TreeEvaluator, with USM memory for temporary buffer to access on device. + * It also contains own implementation of SplitEvaluator for device compilation, because some of the + functions from the original SplitEvaluator are currently not supported + */ + +template +class TreeEvaluator { + // hist and exact use parent id to calculate constraints. + static constexpr bst_node_t kRootParentId = + (-1 & static_cast((1U << 31) - 1)); + + USMVector lower_bounds_; + USMVector upper_bounds_; + USMVector monotone_; + TrainParam param_; + ::sycl::queue qu_; + bool has_constraint_; + + public: + void Reset(::sycl::queue qu, xgboost::tree::TrainParam const& p, bst_feature_t n_features) { + qu_ = qu; + + has_constraint_ = false; + for (const auto& constraint : p.monotone_constraints) { + if (constraint != 0) { + has_constraint_ = true; + break; + } + } + + if (has_constraint_) { + monotone_.Resize(&qu_, n_features, 0); + qu_.memcpy(monotone_.Data(), p.monotone_constraints.data(), + sizeof(int) * p.monotone_constraints.size()); + qu_.wait(); + + lower_bounds_.Resize(&qu_, p.MaxNodes(), std::numeric_limits::lowest()); + upper_bounds_.Resize(&qu_, p.MaxNodes(), std::numeric_limits::max()); + } + param_ = TrainParam(p); + } + + bool HasConstraint() const { + return has_constraint_; + } + + TreeEvaluator(::sycl::queue qu, xgboost::tree::TrainParam const& p, bst_feature_t n_features) { + Reset(qu, p, n_features); + } + + struct SplitEvaluator { + const int* constraints; + const GradType* lower; + const GradType* upper; + bool has_constraint; + TrainParam param; + + GradType CalcSplitGain(bst_node_t nidx, + bst_feature_t fidx, + const GradStats& left, + const GradStats& right) const { + const GradType negative_infinity = -std::numeric_limits::infinity(); + GradType wleft = this->CalcWeight(nidx, left); + GradType wright = this->CalcWeight(nidx, right); + + GradType gain = this->CalcGainGivenWeight(nidx, left, wleft) + + this->CalcGainGivenWeight(nidx, right, wright); + if (!has_constraint) { + return gain; + } + + int constraint = constraints[fidx]; + if (constraint == 0) { + return gain; + } else if (constraint > 0) { + return wleft <= wright ? gain : negative_infinity; + } else { + return wleft >= wright ? gain : negative_infinity; + } + } + + inline static GradType ThresholdL1(GradType w, float alpha) { + if (w > + alpha) { + return w - alpha; + } + if (w < - alpha) { + return w + alpha; + } + return 0.0; + } + + inline GradType CalcWeight(GradType sum_grad, GradType sum_hess) const { + if (sum_hess < param.min_child_weight || sum_hess <= 0.0) { + return 0.0; + } + GradType dw = -this->ThresholdL1(sum_grad, param.reg_alpha) / (sum_hess + param.reg_lambda); + if (param.max_delta_step != 0.0f && std::abs(dw) > param.max_delta_step) { + dw = ::sycl::copysign((GradType)param.max_delta_step, dw); + } + return dw; + } + + inline GradType CalcWeight(bst_node_t nodeid, const GradStats& stats) const { + GradType w = this->CalcWeight(stats.GetGrad(), stats.GetHess()); + if (!has_constraint) { + return w; + } + + if (nodeid == kRootParentId) { + return w; + } else if (w < lower[nodeid]) { + return lower[nodeid]; + } else if (w > upper[nodeid]) { + return upper[nodeid]; + } else { + return w; + } + } + + inline GradType CalcGainGivenWeight(GradType sum_grad, GradType sum_hess, GradType w) const { + return -(2.0f * sum_grad * w + (sum_hess + param.reg_lambda) * xgboost::common::Sqr(w)); + } + + inline GradType CalcGainGivenWeight(bst_node_t nid, const GradStats& stats, + GradType w) const { + if (stats.GetHess() <= 0) { + return .0f; + } + // Avoiding tree::CalcGainGivenWeight can significantly reduce avg floating point error. + if (param.max_delta_step == 0.0f && has_constraint == false) { + return xgboost::common::Sqr(this->ThresholdL1(stats.GetGrad(), param.reg_alpha)) / + (stats.GetHess() + param.reg_lambda); + } + return this->CalcGainGivenWeight(stats.GetGrad(), stats.GetHess(), w); + } + + GradType CalcGain(bst_node_t nid, const GradStats& stats) const { + return this->CalcGainGivenWeight(nid, stats, this->CalcWeight(nid, stats)); + } + }; + + public: + /* Get a view to the evaluator that can be passed down to device. */ + auto GetEvaluator() const { + return SplitEvaluator{monotone_.DataConst(), + lower_bounds_.DataConst(), + upper_bounds_.DataConst(), + has_constraint_, + param_}; + } + + void AddSplit(bst_node_t nodeid, bst_node_t leftid, bst_node_t rightid, + bst_feature_t f, GradType left_weight, GradType right_weight) { + if (!has_constraint_) { + return; + } + + lower_bounds_[leftid] = lower_bounds_[nodeid]; + upper_bounds_[leftid] = upper_bounds_[nodeid]; + + lower_bounds_[rightid] = lower_bounds_[nodeid]; + upper_bounds_[rightid] = upper_bounds_[nodeid]; + int32_t c = monotone_[f]; + GradType mid = (left_weight + right_weight) / 2; + + if (c < 0) { + lower_bounds_[leftid] = mid; + upper_bounds_[rightid] = mid; + } else if (c > 0) { + upper_bounds_[leftid] = mid; + lower_bounds_[rightid] = mid; + } + } +}; +} // namespace tree +} // namespace sycl +} // namespace xgboost + +#endif // PLUGIN_SYCL_TREE_SPLIT_EVALUATOR_H_ diff --git a/plugin/sycl/tree/updater_quantile_hist.cc b/plugin/sycl/tree/updater_quantile_hist.cc new file mode 100644 index 000000000000..98a42c3c8ba0 --- /dev/null +++ b/plugin/sycl/tree/updater_quantile_hist.cc @@ -0,0 +1,55 @@ +/*! + * Copyright 2017-2024 by Contributors + * \file updater_quantile_hist.cc + */ +#include + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#pragma GCC diagnostic ignored "-W#pragma-messages" +#include "xgboost/tree_updater.h" +#pragma GCC diagnostic pop + +#include "xgboost/logging.h" + +#include "updater_quantile_hist.h" +#include "../data.h" + +namespace xgboost { +namespace sycl { +namespace tree { + +DMLC_REGISTRY_FILE_TAG(updater_quantile_hist_sycl); + +DMLC_REGISTER_PARAMETER(HistMakerTrainParam); + +void QuantileHistMaker::Configure(const Args& args) { + const DeviceOrd device_spec = ctx_->Device(); + qu_ = device_manager.GetQueue(device_spec); + + param_.UpdateAllowUnknown(args); + hist_maker_param_.UpdateAllowUnknown(args); +} + +void QuantileHistMaker::Update(xgboost::tree::TrainParam const *param, + linalg::Matrix* gpair, + DMatrix *dmat, + xgboost::common::Span> out_position, + const std::vector &trees) { + LOG(FATAL) << "Not Implemented yet"; +} + +bool QuantileHistMaker::UpdatePredictionCache(const DMatrix* data, + linalg::MatrixView out_preds) { + LOG(FATAL) << "Not Implemented yet"; +} + +XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker_sycl") +.describe("Grow tree using quantized histogram with SYCL.") +.set_body( + [](Context const* ctx, ObjInfo const * task) { + return new QuantileHistMaker(ctx, task); + }); +} // namespace tree +} // namespace sycl +} // namespace xgboost diff --git a/plugin/sycl/tree/updater_quantile_hist.h b/plugin/sycl/tree/updater_quantile_hist.h new file mode 100644 index 000000000000..93a50de3e449 --- /dev/null +++ b/plugin/sycl/tree/updater_quantile_hist.h @@ -0,0 +1,91 @@ +/*! + * Copyright 2017-2024 by Contributors + * \file updater_quantile_hist.h + */ +#ifndef PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_ +#define PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_ + +#include +#include + +#include + +#include "../data/gradient_index.h" +#include "../common/hist_util.h" +#include "../common/row_set.h" +#include "../common/partition_builder.h" +#include "split_evaluator.h" +#include "../device_manager.h" + +#include "xgboost/data.h" +#include "xgboost/json.h" +#include "../../src/tree/constraints.h" +#include "../../src/common/random.h" + +namespace xgboost { +namespace sycl { +namespace tree { + +// training parameters specific to this algorithm +struct HistMakerTrainParam + : public XGBoostParameter { + bool single_precision_histogram = false; + // declare parameters + DMLC_DECLARE_PARAMETER(HistMakerTrainParam) { + DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe( + "Use single precision to build histograms."); + } +}; + +/*! \brief construct a tree using quantized feature values with SYCL backend*/ +class QuantileHistMaker: public TreeUpdater { + public: + QuantileHistMaker(Context const* ctx, ObjInfo const * task) : + TreeUpdater(ctx), task_{task} { + updater_monitor_.Init("SYCLQuantileHistMaker"); + } + void Configure(const Args& args) override; + + void Update(xgboost::tree::TrainParam const *param, + linalg::Matrix* gpair, + DMatrix* dmat, + xgboost::common::Span> out_position, + const std::vector& trees) override; + + bool UpdatePredictionCache(const DMatrix* data, + linalg::MatrixView out_preds) override; + + void LoadConfig(Json const& in) override { + auto const& config = get(in); + FromJson(config.at("train_param"), &this->param_); + FromJson(config.at("sycl_hist_train_param"), &this->hist_maker_param_); + } + + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["train_param"] = ToJson(param_); + out["sycl_hist_train_param"] = ToJson(hist_maker_param_); + } + + char const* Name() const override { + return "grow_quantile_histmaker_sycl"; + } + + protected: + HistMakerTrainParam hist_maker_param_; + // training parameter + xgboost::tree::TrainParam param_; + + xgboost::common::Monitor updater_monitor_; + + ::sycl::queue qu_; + DeviceManager device_manager; + ObjInfo const *task_{nullptr}; +}; + + +} // namespace tree +} // namespace sycl +} // namespace xgboost + +#endif // PLUGIN_SYCL_TREE_UPDATER_QUANTILE_HIST_H_ diff --git a/python-package/xgboost/collective.py b/python-package/xgboost/collective.py index a41d296bf2a6..468d38942f73 100644 --- a/python-package/xgboost/collective.py +++ b/python-package/xgboost/collective.py @@ -1,17 +1,17 @@ """XGBoost collective communication related API.""" import ctypes -import json import logging import os import pickle +import platform from enum import IntEnum, unique from typing import Any, Dict, List, Optional import numpy as np from ._typing import _T -from .core import _LIB, _check_call, build_info, c_str, from_pystr_to_cstr, py_str +from .core import _LIB, _check_call, build_info, c_str, make_jcargs, py_str LOGGER = logging.getLogger("[xgboost.collective]") @@ -21,49 +21,35 @@ def init(**args: Any) -> None: Parameters ---------- - args: Dict[str, Any] + args : Keyword arguments representing the parameters and their values. Accepted parameters: - - xgboost_communicator: The type of the communicator. Can be set as an environment - variable. + - dmlc_communicator: The type of the communicator. * rabit: Use Rabit. This is the default if the type is unspecified. * federated: Use the gRPC interface for Federated Learning. - Only applicable to the Rabit communicator (these are case sensitive): - -- rabit_tracker_uri: Hostname of the tracker. - -- rabit_tracker_port: Port number of the tracker. - -- rabit_task_id: ID of the current task, can be used to obtain deterministic rank - assignment. - -- rabit_world_size: Total number of workers. - -- rabit_hadoop_mode: Enable Hadoop support. - -- rabit_tree_reduce_minsize: Minimal size for tree reduce. - -- rabit_reduce_ring_mincount: Minimal count to perform ring reduce. - -- rabit_reduce_buffer: Size of the reduce buffer. - -- rabit_bootstrap_cache: Size of the bootstrap cache. - -- rabit_debug: Enable debugging. - -- rabit_timeout: Enable timeout. - -- rabit_timeout_sec: Timeout in seconds. - -- rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms. - Only applicable to the Rabit communicator (these are case-sensitive, and can be set as - environment variables): - -- DMLC_TRACKER_URI: Hostname of the tracker. - -- DMLC_TRACKER_PORT: Port number of the tracker. - -- DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank - assignment. - -- DMLC_ROLE: Role of the current task, "worker" or "server". - -- DMLC_NUM_ATTEMPT: Number of attempts after task failure. - -- DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker. - Only applicable to the Federated communicator (use upper case for environment variables, use - lower case for runtime configuration): - -- federated_server_address: Address of the federated server. - -- federated_world_size: Number of federated workers. - -- federated_rank: Rank of the current worker. - -- federated_server_cert: Server certificate file path. Only needed for the SSL mode. - -- federated_client_key: Client key file path. Only needed for the SSL mode. - -- federated_client_cert: Client certificate file path. Only needed for the SSL mode. + + Only applicable to the Rabit communicator: + - dmlc_tracker_uri: Hostname of the tracker. + - dmlc_tracker_port: Port number of the tracker. + - dmlc_task_id: ID of the current task, can be used to obtain deterministic + - dmlc_retry: The number of retry when handling network errors. + - dmlc_timeout: Timeout in seconds. + - dmlc_nccl_path: Path to load (dlopen) nccl for GPU-based communication. + + Only applicable to the Federated communicator (use upper case for environment + variables, use lower case for runtime configuration): + + - federated_server_address: Address of the federated server. + - federated_world_size: Number of federated workers. + - federated_rank: Rank of the current worker. + - federated_server_cert: Server certificate file path. Only needed for the SSL + mode. + - federated_client_key: Client key file path. Only needed for the SSL mode. + - federated_client_cert: Client certificate file path. Only needed for the SSL + mode. """ - config = from_pystr_to_cstr(json.dumps(args)) - _check_call(_LIB.XGCommunicatorInit(config)) + _check_call(_LIB.XGCommunicatorInit(make_jcargs(**args))) def finalize() -> None: @@ -157,7 +143,7 @@ def broadcast(data: _T, root: int) -> _T: assert data is not None, "need to pass in data when broadcasting" s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL) length.value = len(s) - # run first broadcast + # Run first broadcast _check_call( _LIB.XGCommunicatorBroadcast( ctypes.byref(length), ctypes.sizeof(ctypes.c_ulong), root @@ -184,16 +170,27 @@ def broadcast(data: _T, root: int) -> _T: # enumeration of dtypes -DTYPE_ENUM__ = { - np.dtype("int8"): 0, - np.dtype("uint8"): 1, - np.dtype("int32"): 2, - np.dtype("uint32"): 3, - np.dtype("int64"): 4, - np.dtype("uint64"): 5, - np.dtype("float32"): 6, - np.dtype("float64"): 7, -} +def _map_dtype(dtype: np.dtype) -> int: + dtype_map = { + np.dtype("float16"): 0, + np.dtype("float32"): 1, + np.dtype("float64"): 2, + np.dtype("int8"): 4, + np.dtype("int16"): 5, + np.dtype("int32"): 6, + np.dtype("int64"): 7, + np.dtype("uint8"): 8, + np.dtype("uint16"): 9, + np.dtype("uint32"): 10, + np.dtype("uint64"): 11, + } + if platform.system() != "Windows": + dtype_map.update({np.dtype("float128"): 3}) + + if dtype not in dtype_map: + raise TypeError(f"data type {dtype} is not supported on the current platform.") + + return dtype_map[dtype] @unique @@ -229,24 +226,23 @@ def allreduce(data: np.ndarray, op: Op) -> np.ndarray: # pylint:disable=invalid """ if not isinstance(data, np.ndarray): raise TypeError("allreduce only takes in numpy.ndarray") - buf = data.ravel() - if buf.base is data.base: - buf = buf.copy() - if buf.dtype not in DTYPE_ENUM__: - raise TypeError(f"data type {buf.dtype} not supported") + buf = data.ravel().copy() _check_call( _LIB.XGCommunicatorAllreduce( buf.ctypes.data_as(ctypes.c_void_p), buf.size, - DTYPE_ENUM__[buf.dtype], + _map_dtype(buf.dtype), int(op), - None, - None, ) ) return buf +def signal_error() -> None: + """Kill the process.""" + _check_call(_LIB.XGCommunicatorSignalError()) + + class CommunicatorContext: """A context controlling collective communicator initialization and finalization.""" diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 36e4bdcf0d2d..76251d65c522 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -295,7 +295,7 @@ def _check_distributed_params(kwargs: Dict[str, Any]) -> None: if device and device.find(":") != -1: raise ValueError( "Distributed training doesn't support selecting device ordinal as GPUs are" - " managed by the distributed framework. use `device=cuda` or `device=gpu`" + " managed by the distributed frameworks. use `device=cuda` or `device=gpu`" " instead." ) @@ -504,8 +504,10 @@ class DataIter(ABC): # pylint: disable=too-many-instance-attributes cache_prefix : Prefix to the cache files, only used in external memory. release_data : - Whether the iterator should release the data during reset. Set it to True if the - data transformation (converting data to np.float32 type) is expensive. + Whether the iterator should release the data during iteration. Set it to True if + the data transformation (converting data to np.float32 type) is memory + intensive. Otherwise, if the transformation is computation intensive then we can + keep the cache. """ @@ -517,15 +519,12 @@ def __init__( self._handle = _ProxyDMatrix() self._exception: Optional[Exception] = None self._enable_categorical = False - self._allow_host = True self._release = release_data # Stage data in Python until reset or next is called to avoid data being free. self._temporary_data: Optional[TransformedData] = None self._data_ref: Optional[weakref.ReferenceType] = None - def get_callbacks( - self, allow_host: bool, enable_categorical: bool - ) -> Tuple[Callable, Callable]: + def get_callbacks(self, enable_categorical: bool) -> Tuple[Callable, Callable]: """Get callback functions for iterating in C. This is an internal function.""" assert hasattr(self, "cache_prefix"), "__init__ is not called." self._reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)( @@ -535,7 +534,6 @@ def get_callbacks( ctypes.c_int, ctypes.c_void_p, )(self._next_wrapper) - self._allow_host = allow_host self._enable_categorical = enable_categorical return self._reset_callback, self._next_callback @@ -624,7 +622,7 @@ def input_data( ) # Stage the data, meta info are copied inside C++ MetaInfo. self._temporary_data = (new, cat_codes, feature_names, feature_types) - dispatch_proxy_set_data(self.proxy, new, cat_codes, self._allow_host) + dispatch_proxy_set_data(self.proxy, new, cat_codes) self.proxy.set_info( feature_names=feature_names, feature_types=feature_types, @@ -632,6 +630,9 @@ def input_data( ) self._data_ref = ref + # Release the data before next batch is loaded. + if self._release: + self._temporary_data = None # pylint: disable=not-callable return self._handle_exception(lambda: self.next(input_data), 0) @@ -911,7 +912,7 @@ def _init_from_iter(self, iterator: DataIter, enable_categorical: bool) -> None: } args_cstr = from_pystr_to_cstr(json.dumps(args)) handle = ctypes.c_void_p() - reset_callback, next_callback = it.get_callbacks(True, enable_categorical) + reset_callback, next_callback = it.get_callbacks(enable_categorical) ret = _LIB.XGDMatrixCreateFromCallback( None, it.proxy.handle, @@ -1437,37 +1438,37 @@ def __init__(self) -> None: # pylint: disable=super-init-not-called self.handle = ctypes.c_void_p() _check_call(_LIB.XGProxyDMatrixCreate(ctypes.byref(self.handle))) - def _set_data_from_cuda_interface(self, data: DataType) -> None: - """Set data from CUDA array interface.""" + def _ref_data_from_cuda_interface(self, data: DataType) -> None: + """Reference data from CUDA array interface.""" interface = data.__cuda_array_interface__ interface_str = bytes(json.dumps(interface), "utf-8") _check_call( _LIB.XGProxyDMatrixSetDataCudaArrayInterface(self.handle, interface_str) ) - def _set_data_from_cuda_columnar(self, data: DataType, cat_codes: list) -> None: - """Set data from CUDA columnar format.""" + def _ref_data_from_cuda_columnar(self, data: DataType, cat_codes: list) -> None: + """Reference data from CUDA columnar format.""" from .data import _cudf_array_interfaces interfaces_str = _cudf_array_interfaces(data, cat_codes) _check_call(_LIB.XGProxyDMatrixSetDataCudaColumnar(self.handle, interfaces_str)) - def _set_data_from_array(self, data: np.ndarray) -> None: - """Set data from numpy array.""" + def _ref_data_from_array(self, data: np.ndarray) -> None: + """Reference data from numpy array.""" from .data import _array_interface _check_call( _LIB.XGProxyDMatrixSetDataDense(self.handle, _array_interface(data)) ) - def _set_data_from_pandas(self, data: DataType) -> None: - """Set data from a pandas DataFrame. The input is a PandasTransformed instance.""" + def _ref_data_from_pandas(self, data: DataType) -> None: + """Reference data from a pandas DataFrame. The input is a PandasTransformed instance.""" _check_call( _LIB.XGProxyDMatrixSetDataColumnar(self.handle, data.array_interface()) ) - def _set_data_from_csr(self, csr: scipy.sparse.csr_matrix) -> None: - """Set data from scipy csr""" + def _ref_data_from_csr(self, csr: scipy.sparse.csr_matrix) -> None: + """Reference data from scipy csr.""" from .data import _array_interface _LIB.XGProxyDMatrixSetDataCSR( @@ -1609,7 +1610,7 @@ def _init( it = SingleBatchInternalIter(data=data, **meta) handle = ctypes.c_void_p() - reset_callback, next_callback = it.get_callbacks(True, enable_categorical) + reset_callback, next_callback = it.get_callbacks(enable_categorical) if it.cache_prefix is not None: raise ValueError( "QuantileDMatrix doesn't cache data, remove the cache_prefix " diff --git a/python-package/xgboost/dask/__init__.py b/python-package/xgboost/dask/__init__.py index a3c549a2e515..c74caecdb4df 100644 --- a/python-package/xgboost/dask/__init__.py +++ b/python-package/xgboost/dask/__init__.py @@ -71,6 +71,7 @@ Metric, Objective, QuantileDMatrix, + XGBoostError, _check_distributed_params, _deprecate_positional_args, _expect, @@ -90,7 +91,7 @@ _wrap_evaluation_matrices, xgboost_model_doc, ) -from xgboost.tracker import RabitTracker, get_host_ip +from xgboost.tracker import RabitTracker from xgboost.training import train as worker_train from .utils import get_n_threads @@ -160,36 +161,38 @@ def _try_start_tracker( n_workers: int, addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]], ) -> Dict[str, Union[int, str]]: - env: Dict[str, Union[int, str]] = {"DMLC_NUM_WORKER": n_workers} + env: Dict[str, Union[int, str]] = {} try: if isinstance(addrs[0], tuple): host_ip = addrs[0][0] port = addrs[0][1] rabit_tracker = RabitTracker( - host_ip=get_host_ip(host_ip), n_workers=n_workers, + host_ip=host_ip, port=port, - use_logger=False, + sortby="task", ) else: addr = addrs[0] assert isinstance(addr, str) or addr is None - host_ip = get_host_ip(addr) rabit_tracker = RabitTracker( - host_ip=host_ip, n_workers=n_workers, use_logger=False, sortby="task" + n_workers=n_workers, host_ip=addr, sortby="task" ) - env.update(rabit_tracker.worker_envs()) - rabit_tracker.start(n_workers) - thread = Thread(target=rabit_tracker.join) + + rabit_tracker.start() + thread = Thread(target=rabit_tracker.wait_for) thread.daemon = True thread.start() - except socket.error as e: - if len(addrs) < 2 or e.errno != 99: + env.update(rabit_tracker.worker_args()) + + except XGBoostError as e: + if len(addrs) < 2: raise LOGGER.warning( - "Failed to bind address '%s', trying to use '%s' instead.", + "Failed to bind address '%s', trying to use '%s' instead. Error:\n %s", str(addrs[0]), str(addrs[1]), + str(e), ) env = _try_start_tracker(n_workers, addrs[1:]) @@ -616,7 +619,7 @@ def __init__( assert isinstance(self._label_upper_bound, types) self._iter = 0 # set iterator to 0 - super().__init__() + super().__init__(release_data=True) def _get(self, attr: str) -> Optional[Any]: if getattr(self, attr) is not None: diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 07a08dc5f0b2..7e0ae793ba6e 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -233,9 +233,9 @@ def _maybe_np_slice(data: DataType, dtype: Optional[NumpyDType]) -> np.ndarray: if not data.flags.c_contiguous: data = np.array(data, copy=True, dtype=dtype) else: - data = np.array(data, copy=False, dtype=dtype) + data = np.asarray(data, dtype=dtype) except AttributeError: - data = np.array(data, copy=False, dtype=dtype) + data = np.asarray(data, dtype=dtype) data, dtype = _ensure_np_dtype(data, dtype) return data @@ -370,10 +370,8 @@ def pandas_feature_info( if feature_names is None and meta is None: if isinstance(data.columns, pd.MultiIndex): feature_names = [" ".join([str(x) for x in i]) for i in data.columns] - elif isinstance(data.columns, (pd.Index, pd.RangeIndex)): - feature_names = list(map(str, data.columns)) else: - feature_names = data.columns.format() + feature_names = list(data.columns.map(str)) # handle feature types if feature_types is None and meta is None: @@ -483,7 +481,7 @@ def cat_codes(ser: pd.Series) -> np.ndarray: if is_pd_cat_dtype(ser.dtype): return _ensure_np_dtype( ser.cat.codes.astype(np.float32) - .replace(-1.0, np.NaN) + .replace(-1.0, np.nan) .to_numpy(na_value=np.nan), np.float32, )[0] @@ -495,7 +493,7 @@ def cat_codes(ser: pd.Series) -> np.ndarray: .combine_chunks() .dictionary_encode() .indices.astype(np.float32) - .replace(-1.0, np.NaN) + .replace(-1.0, np.nan) ) def nu_type(ser: pd.Series) -> np.ndarray: @@ -865,6 +863,22 @@ def _is_cudf_df(data: DataType) -> bool: return lazy_isinstance(data, "cudf.core.dataframe", "DataFrame") +def _get_cudf_cat_predicate() -> Callable[[Any], bool]: + try: + from cudf import CategoricalDtype + + def is_categorical_dtype(dtype: Any) -> bool: + return isinstance(dtype, CategoricalDtype) + + except ImportError: + try: + from cudf.api.types import is_categorical_dtype # type: ignore + except ImportError: + from cudf.utils.dtypes import is_categorical_dtype # type: ignore + + return is_categorical_dtype + + def _cudf_array_interfaces(data: DataType, cat_codes: list) -> bytes: """Extract CuDF __cuda_array_interface__. This is special as it returns a new list of data and a list of array interfaces. The data is list of categorical codes that @@ -872,11 +886,7 @@ def _cudf_array_interfaces(data: DataType, cat_codes: list) -> bytes: array interface is finished. """ - try: - from cudf.api.types import is_categorical_dtype - except ImportError: - from cudf.utils.dtypes import is_categorical_dtype - + is_categorical_dtype = _get_cudf_cat_predicate() interfaces = [] def append(interface: dict) -> None: @@ -908,10 +918,21 @@ def _transform_cudf_df( feature_types: Optional[FeatureTypes], enable_categorical: bool, ) -> Tuple[ctypes.c_void_p, list, Optional[FeatureNames], Optional[FeatureTypes]]: + try: - from cudf.api.types import is_categorical_dtype + from cudf.api.types import is_bool_dtype except ImportError: - from cudf.utils.dtypes import is_categorical_dtype + from pandas.api.types import is_bool_dtype + + is_categorical_dtype = _get_cudf_cat_predicate() + # Work around https://github.com/dmlc/xgboost/issues/10181 + if _is_cudf_ser(data): + if is_bool_dtype(data.dtype): + data = data.astype(np.uint8) + else: + data = data.astype( + {col: np.uint8 for col in data.select_dtypes(include="bool")} + ) if _is_cudf_ser(data): dtypes = [data.dtype] @@ -931,15 +952,8 @@ def _transform_cudf_df( feature_names = [data.name] elif lazy_isinstance(data.columns, "cudf.core.multiindex", "MultiIndex"): feature_names = [" ".join([str(x) for x in i]) for i in data.columns] - elif ( - lazy_isinstance(data.columns, "cudf.core.index", "RangeIndex") - or lazy_isinstance(data.columns, "cudf.core.index", "Int64Index") - # Unique to cuDF, no equivalence in pandas 1.3.3 - or lazy_isinstance(data.columns, "cudf.core.index", "Int32Index") - ): - feature_names = list(map(str, data.columns)) else: - feature_names = data.columns.format() + feature_names = list(data.columns.map(str)) # handle feature types if feature_types is None: @@ -1453,7 +1467,6 @@ def dispatch_proxy_set_data( proxy: _ProxyDMatrix, data: DataType, cat_codes: Optional[list], - allow_host: bool, ) -> None: """Dispatch for QuantileDMatrix.""" if not _is_cudf_ser(data) and not _is_pandas_series(data): @@ -1461,33 +1474,30 @@ def dispatch_proxy_set_data( if _is_cudf_df(data): # pylint: disable=W0212 - proxy._set_data_from_cuda_columnar(data, cast(List, cat_codes)) + proxy._ref_data_from_cuda_columnar(data, cast(List, cat_codes)) return if _is_cudf_ser(data): # pylint: disable=W0212 - proxy._set_data_from_cuda_columnar(data, cast(List, cat_codes)) + proxy._ref_data_from_cuda_columnar(data, cast(List, cat_codes)) return if _is_cupy_alike(data): - proxy._set_data_from_cuda_interface(data) # pylint: disable=W0212 + proxy._ref_data_from_cuda_interface(data) # pylint: disable=W0212 return if _is_dlpack(data): data = _transform_dlpack(data) - proxy._set_data_from_cuda_interface(data) # pylint: disable=W0212 + proxy._ref_data_from_cuda_interface(data) # pylint: disable=W0212 return - - err = TypeError("Value type is not supported for data iterator:" + str(type(data))) - - if not allow_host: - raise err - + # Host if isinstance(data, PandasTransformed): - proxy._set_data_from_pandas(data) # pylint: disable=W0212 + proxy._ref_data_from_pandas(data) # pylint: disable=W0212 return if _is_np_array_like(data): _check_data_shape(data) - proxy._set_data_from_array(data) # pylint: disable=W0212 + proxy._ref_data_from_array(data) # pylint: disable=W0212 return if is_scipy_csr(data): - proxy._set_data_from_csr(data) # pylint: disable=W0212 + proxy._ref_data_from_csr(data) # pylint: disable=W0212 return + + err = TypeError("Value type is not supported for data iterator:" + str(type(data))) raise err diff --git a/python-package/xgboost/federated.py b/python-package/xgboost/federated.py index 0214e4e2066a..dcba9ec81a68 100644 --- a/python-package/xgboost/federated.py +++ b/python-package/xgboost/federated.py @@ -1,45 +1,85 @@ -"""XGBoost Federated Learning related API.""" +"""XGBoost Experimental Federated Learning related API.""" -from .core import _LIB, XGBoostError, _check_call, build_info, c_str +import ctypes +from threading import Thread +from typing import Any, Dict, Optional +from .core import _LIB, _check_call, make_jcargs +from .tracker import RabitTracker -def run_federated_server( - port: int, - world_size: int, - server_key_path: str = "", - server_cert_path: str = "", - client_cert_path: str = "", -) -> None: - """Run the Federated Learning server. + +class FederatedTracker(RabitTracker): + """Tracker for federated training. Parameters ---------- - port : int - The port to listen on. - world_size: int + n_workers : The number of federated workers. - server_key_path: str - Path to the server private key file. SSL is turned off if empty. - server_cert_path: str - Path to the server certificate file. SSL is turned off if empty. - client_cert_path: str - Path to the client certificate file. SSL is turned off if empty. + + port : + The port to listen on. + + secure : + Whether this is a secure instance. If True, then the following arguments for SSL + must be provided. + + server_key_path : + Path to the server private key file. + + server_cert_path : + Path to the server certificate file. + + client_cert_path : + Path to the client certificate file. + """ - if build_info()["USE_FEDERATED"]: - if not server_key_path or not server_cert_path or not client_cert_path: - _check_call(_LIB.XGBRunInsecureFederatedServer(port, world_size)) - else: - _check_call( - _LIB.XGBRunFederatedServer( - port, - world_size, - c_str(server_key_path), - c_str(server_cert_path), - c_str(client_cert_path), - ) - ) - else: - raise XGBoostError( - "XGBoost needs to be built with the federated learning plugin " - "enabled in order to use this module" + + def __init__( # pylint: disable=R0913, W0231 + self, + n_workers: int, + port: int, + secure: bool, + server_key_path: str = "", + server_cert_path: str = "", + client_cert_path: str = "", + timeout: int = 300, + ) -> None: + handle = ctypes.c_void_p() + args = make_jcargs( + n_workers=n_workers, + port=port, + dmlc_communicator="federated", + federated_secure=secure, + server_key_path=server_key_path, + server_cert_path=server_cert_path, + client_cert_path=client_cert_path, + timeout=int(timeout), ) + _check_call(_LIB.XGTrackerCreate(args, ctypes.byref(handle))) + self.handle = handle + + +def run_federated_server( # pylint: disable=too-many-arguments + n_workers: int, + port: int, + server_key_path: Optional[str] = None, + server_cert_path: Optional[str] = None, + client_cert_path: Optional[str] = None, + timeout: int = 300, +) -> Dict[str, Any]: + """See :py:class:`~xgboost.federated.FederatedTracker` for more info.""" + args: Dict[str, Any] = {"n_workers": n_workers} + secure = all( + path is not None + for path in [server_key_path, server_cert_path, client_cert_path] + ) + tracker = FederatedTracker( + n_workers=n_workers, port=port, secure=secure, timeout=timeout + ) + tracker.start() + + thread = Thread(target=tracker.wait_for) + thread.daemon = True + thread.start() + args.update(tracker.worker_args()) + return args diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index c4713a9e49c7..560a3a8ed285 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -782,7 +782,10 @@ def __init__( def _more_tags(self) -> Dict[str, bool]: """Tags used for scikit-learn data validation.""" - return {"allow_nan": True, "no_validation": True} + tags = {"allow_nan": True, "no_validation": True} + if hasattr(self, "kwargs") and self.kwargs.get("updater") == "shotgun": + tags["non_deterministic"] = True + return tags def __sklearn_is_fitted__(self) -> bool: return hasattr(self, "_Booster") @@ -1439,6 +1442,11 @@ def __init__( ) -> None: super().__init__(objective=objective, **kwargs) + def _more_tags(self) -> Dict[str, bool]: + tags = super()._more_tags() + tags["multilabel"] = True + return tags + @_deprecate_positional_args def fit( self, @@ -1717,6 +1725,12 @@ def __init__( ) -> None: super().__init__(objective=objective, **kwargs) + def _more_tags(self) -> Dict[str, bool]: + tags = super()._more_tags() + tags["multioutput"] = True + tags["multioutput_only"] = False + return tags + @xgboost_model_doc( "scikit-learn API for XGBoost random forest regression.", diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index eb226611dd6c..8134ec7e7a72 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -95,6 +95,7 @@ deserialize_xgb_model, get_class_name, get_logger, + get_logger_level, serialize_booster, use_cuda, ) @@ -181,6 +182,8 @@ _INIT_BOOSTER_SAVE_PATH = "init_booster.json" +_LOG_TAG = "XGBoost-PySpark" + class _SparkXGBParams( HasFeaturesCol, @@ -344,15 +347,14 @@ def _gen_predict_params_dict(self) -> Dict[str, Any]: predict_params[param.name] = self.getOrDefault(param) return predict_params - def _validate_gpu_params(self) -> None: + def _validate_gpu_params( + self, spark_version: str, conf: SparkConf, is_local: bool = False + ) -> None: """Validate the gpu parameters and gpu configurations""" if self._run_on_gpu(): - ss = _get_spark_session() - sc = ss.sparkContext - - if _is_local(sc): - # Support GPU training in Spark local mode is just for debugging + if is_local: + # Supporting GPU training in Spark local mode is just for debugging # purposes, so it's okay for printing the below warning instead of # checking the real gpu numbers and raising the exception. get_logger(self.__class__.__name__).warning( @@ -361,33 +363,41 @@ def _validate_gpu_params(self) -> None: self.getOrDefault(self.num_workers), ) else: - executor_gpus = sc.getConf().get("spark.executor.resource.gpu.amount") + executor_gpus = conf.get("spark.executor.resource.gpu.amount") if executor_gpus is None: raise ValueError( "The `spark.executor.resource.gpu.amount` is required for training" " on GPU." ) - - if not ( - ss.version >= "3.4.0" - and _is_standalone_or_localcluster(sc.getConf()) + gpu_per_task = conf.get("spark.task.resource.gpu.amount") + if gpu_per_task is not None and float(gpu_per_task) > 1.0: + get_logger(self.__class__.__name__).warning( + "The configuration assigns %s GPUs to each Spark task, but each " + "XGBoost training task only utilizes 1 GPU, which will lead to " + "unnecessary GPU waste", + gpu_per_task, + ) + # For 3.5.1+, Spark supports task stage-level scheduling for + # Yarn/K8s/Standalone/Local cluster + # From 3.4.0 ~ 3.5.0, Spark only supports task stage-level scheduing for + # Standalone/Local cluster + # For spark below 3.4.0, Task stage-level scheduling is not supported. + # + # With stage-level scheduling, spark.task.resource.gpu.amount is not required + # to be set explicitly. Or else, spark.task.resource.gpu.amount is a must-have and + # must be set to 1.0 + if spark_version < "3.4.0" or ( + "3.4.0" <= spark_version < "3.5.1" + and not _is_standalone_or_localcluster(conf) ): - # We will enable stage-level scheduling in spark 3.4.0+ which doesn't - # require spark.task.resource.gpu.amount to be set explicitly - gpu_per_task = sc.getConf().get("spark.task.resource.gpu.amount") if gpu_per_task is not None: if float(gpu_per_task) < 1.0: raise ValueError( - "XGBoost doesn't support GPU fractional configurations. " - "Please set `spark.task.resource.gpu.amount=spark.executor" - ".resource.gpu.amount`" - ) - - if float(gpu_per_task) > 1.0: - get_logger(self.__class__.__name__).warning( - "%s GPUs for each Spark task is configured, but each " - "XGBoost training task uses only 1 GPU.", - gpu_per_task, + "XGBoost doesn't support GPU fractional configurations. Please set " + "`spark.task.resource.gpu.amount=spark.executor.resource.gpu." + "amount`. To enable GPU fractional configurations, you can try " + "standalone/localcluster with spark 3.4.0+ and" + "YARN/K8S with spark 3.5.1+" ) else: raise ValueError( @@ -472,7 +482,9 @@ def _validate_params(self) -> None: "`pyspark.ml.linalg.Vector` type." ) - self._validate_gpu_params() + ss = _get_spark_session() + sc = ss.sparkContext + self._validate_gpu_params(ss.version, sc.getConf(), _is_local(sc)) def _run_on_gpu(self) -> bool: """If train or transform on the gpu according to the parameters""" @@ -515,7 +527,8 @@ def _validate_and_convert_feature_col_as_array_col( (DoubleType, FloatType, LongType, IntegerType, ShortType), ): raise ValueError( - "If feature column is array type, its elements must be number type." + "If feature column is array type, its elements must be number type, " + f"got {features_col_datatype.elementType}." ) features_array_col = features_col.cast(ArrayType(FloatType())).alias(alias.data) elif isinstance(features_col_datatype, VectorUDT): @@ -922,10 +935,14 @@ def _skip_stage_level_scheduling(self, spark_version: str, conf: SparkConf) -> b ) return True - if not _is_standalone_or_localcluster(conf): + if ( + "3.4.0" <= spark_version < "3.5.1" + and not _is_standalone_or_localcluster(conf) + ): self.logger.info( - "Stage-level scheduling in xgboost requires spark standalone or " - "local-cluster mode" + "For %s, Stage-level scheduling in xgboost requires spark standalone " + "or local-cluster mode", + spark_version, ) return True @@ -977,7 +994,9 @@ def _try_stage_level_scheduling(self, rdd: RDD) -> RDD: """Try to enable stage-level scheduling""" ss = _get_spark_session() conf = ss.sparkContext.getConf() - if self._skip_stage_level_scheduling(ss.version, conf): + if _is_local(ss.sparkContext) or self._skip_stage_level_scheduling( + ss.version, conf + ): return rdd # executor_cores will not be None @@ -1034,6 +1053,8 @@ def _fit(self, dataset: DataFrame) -> "_SparkXGBModel": num_workers = self.getOrDefault(self.num_workers) + log_level = get_logger_level(_LOG_TAG) + def _train_booster( pandas_df_iter: Iterator[pd.DataFrame], ) -> Iterator[pd.DataFrame]: @@ -1047,7 +1068,8 @@ def _train_booster( dev_ordinal = None use_qdm = _can_use_qdm(booster_params.get("tree_method", None)) - + verbosity = booster_params.get("verbosity", 1) + msg = "Training on CPUs" if run_on_gpu: dev_ordinal = ( context.partitionId() if is_local else _get_gpu_id(context) @@ -1058,10 +1080,9 @@ def _train_booster( # Note: Checking `is_cudf_available` in spark worker side because # spark worker might has different python environment with driver side. use_qdm = use_qdm and is_cudf_available() - get_logger("XGBoost-PySpark").info( - "Leveraging %s to train with QDM: %s", - booster_params["device"], - "on" if use_qdm else "off", + msg = ( + f"Leveraging {booster_params['device']} to train with " + f"QDM: {'on' if use_qdm else 'off'}" ) if use_qdm and (booster_params.get("max_bin", None) is not None): @@ -1070,6 +1091,7 @@ def _train_booster( _rabit_args = {} if context.partitionId() == 0: _rabit_args = _get_rabit_args(context, num_workers) + get_logger(_LOG_TAG, log_level).info(msg) worker_message = { "rabit_msg": _rabit_args, @@ -1084,15 +1106,16 @@ def _train_booster( evals_result: Dict[str, Any] = {} with CommunicatorContext(context, **_rabit_args): - dtrain, dvalid = create_dmatrix_from_partitions( - pandas_df_iter, - feature_prop.features_cols_names, - dev_ordinal, - use_qdm, - dmatrix_kwargs, - enable_sparse_data_optim=feature_prop.enable_sparse_data_optim, - has_validation_col=feature_prop.has_validation_col, - ) + with xgboost.config_context(verbosity=verbosity): + dtrain, dvalid = create_dmatrix_from_partitions( + pandas_df_iter, + feature_prop.features_cols_names, + dev_ordinal, + use_qdm, + dmatrix_kwargs, + enable_sparse_data_optim=feature_prop.enable_sparse_data_optim, + has_validation_col=feature_prop.has_validation_col, + ) if dvalid is not None: dval = [(dtrain, "training"), (dvalid, "validation")] else: @@ -1127,7 +1150,7 @@ def _run_job() -> Tuple[str, str]: ret = rdd_with_resource.collect()[0] return ret[0], ret[1] - get_logger("XGBoost-PySpark").info( + get_logger(_LOG_TAG).info( "Running xgboost-%s on %s workers with" "\n\tbooster params: %s" "\n\ttrain_call_kwargs_params: %s" @@ -1139,7 +1162,7 @@ def _run_job() -> Tuple[str, str]: dmatrix_kwargs, ) (config, booster) = _run_job() - get_logger("XGBoost-PySpark").info("Finished xgboost training!") + get_logger(_LOG_TAG).info("Finished xgboost training!") result_xgb_model = self._convert_to_sklearn_model( bytearray(booster, "utf-8"), config @@ -1342,7 +1365,7 @@ def _run_on_gpu(self) -> bool: # User don't set gpu configurations, just use cpu if gpu_per_task is None: if use_gpu_by_params: - get_logger("XGBoost-PySpark").warning( + get_logger(_LOG_TAG).warning( "Do the prediction on the CPUs since " "no gpu configurations are set" ) @@ -1357,15 +1380,15 @@ def _transform(self, dataset: DataFrame) -> DataFrame: # to avoid the `self` object to be pickled to remote. xgb_sklearn_model = self._xgb_sklearn_model - has_base_margin = False + base_margin_col = None if ( self.isDefined(self.base_margin_col) and self.getOrDefault(self.base_margin_col) != "" ): - has_base_margin = True base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias( alias.margin ) + has_base_margin = base_margin_col is not None features_col, feature_col_names = self._get_feature_col(dataset) enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim) @@ -1377,6 +1400,8 @@ def _transform(self, dataset: DataFrame) -> DataFrame: is_local = _is_local(_get_spark_session().sparkContext) run_on_gpu = self._run_on_gpu() + log_level = get_logger_level(_LOG_TAG) + @pandas_udf(schema) # type: ignore def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]: assert xgb_sklearn_model is not None @@ -1413,7 +1438,8 @@ def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]: else: msg = "CUDF or Cupy is unavailable, fallback the inference on the CPUs" - get_logger("XGBoost-PySpark").info(msg) + if context.partitionId() == 0: + get_logger(_LOG_TAG, log_level).info(msg) def to_gpu_if_possible(data: ArrayLike) -> ArrayLike: """Move the data to gpu if possible""" @@ -1447,6 +1473,7 @@ def to_gpu_if_possible(data: ArrayLike) -> ArrayLike: yield predict_func(model, X, base_margin) if has_base_margin: + assert base_margin_col is not None pred_col = predict_udf(struct(*features_col, base_margin_col)) else: pred_col = predict_udf(struct(*features_col)) diff --git a/python-package/xgboost/spark/data.py b/python-package/xgboost/spark/data.py index f9c12ba6628d..9c21f6ae8577 100644 --- a/python-package/xgboost/spark/data.py +++ b/python-package/xgboost/spark/data.py @@ -77,7 +77,7 @@ def __init__( self._data = data self._kwargs = kwargs - super().__init__() + super().__init__(release_data=True) def _fetch(self, data: Optional[Sequence[pd.DataFrame]]) -> Optional[pd.DataFrame]: if not data: diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py index 84333df53dd9..0a421031ecd4 100644 --- a/python-package/xgboost/spark/utils.py +++ b/python-package/xgboost/spark/utils.py @@ -8,13 +8,14 @@ import sys import uuid from threading import Thread -from typing import Any, Callable, Dict, Optional, Set, Type +from typing import Any, Callable, Dict, Optional, Set, Type, Union import pyspark from pyspark import BarrierTaskContext, SparkConf, SparkContext, SparkFiles, TaskContext from pyspark.sql.session import SparkSession -from xgboost import Booster, XGBModel, collective +from xgboost import Booster, XGBModel +from xgboost.collective import CommunicatorContext as CCtx from xgboost.tracker import RabitTracker @@ -42,35 +43,25 @@ def _get_default_params_from_func( return filtered_params_dict -class CommunicatorContext: - """A context controlling collective communicator initialization and finalization. - This isn't specificially necessary (note Part 3), but it is more understandable - coding-wise. - - """ +class CommunicatorContext(CCtx): # pylint: disable=too-few-public-methods + """Context with PySpark specific task ID.""" def __init__(self, context: BarrierTaskContext, **args: Any) -> None: - self.args = args - self.args["DMLC_TASK_ID"] = str(context.partitionId()) - - def __enter__(self) -> None: - collective.init(**self.args) - - def __exit__(self, *args: Any) -> None: - collective.finalize() + args["dmlc_task_id"] = str(context.partitionId()) + super().__init__(**args) def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]: """Start Rabit tracker with n_workers""" - env: Dict[str, Any] = {"DMLC_NUM_WORKER": n_workers} + args: Dict[str, Any] = {"n_workers": n_workers} host = _get_host_ip(context) - rabit_context = RabitTracker(host_ip=host, n_workers=n_workers) - env.update(rabit_context.worker_envs()) - rabit_context.start(n_workers) - thread = Thread(target=rabit_context.join) + tracker = RabitTracker(n_workers=n_workers, host_ip=host, sortby="task") + tracker.start() + thread = Thread(target=tracker.wait_for) thread.daemon = True thread.start() - return env + args.update(tracker.worker_args()) + return args def _get_rabit_args(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]: @@ -98,10 +89,15 @@ def _get_spark_session() -> SparkSession: return SparkSession.builder.getOrCreate() -def get_logger(name: str, level: str = "INFO") -> logging.Logger: +def get_logger(name: str, level: Optional[Union[str, int]] = None) -> logging.Logger: """Gets a logger by name, or creates and configures it for the first time.""" logger = logging.getLogger(name) - logger.setLevel(level) + if level is not None: + logger.setLevel(level) + else: + # Default to info if not set. + if logger.level == logging.NOTSET: + logger.setLevel(logging.INFO) # If the logger is configured, skip the configure if not logger.handlers and not logging.getLogger().handlers: handler = logging.StreamHandler(sys.stderr) @@ -113,6 +109,12 @@ def get_logger(name: str, level: str = "INFO") -> logging.Logger: return logger +def get_logger_level(name: str) -> Optional[int]: + """Get the logger level for the given log name""" + logger = logging.getLogger(name) + return None if logger.level == logging.NOTSET else logger.level + + def _get_max_num_concurrent_tasks(spark_context: SparkContext) -> int: """Gets the current max number of concurrent tasks.""" # pylint: disable=protected-access diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py index f7d9510faea6..b85c0f3251b6 100644 --- a/python-package/xgboost/testing/__init__.py +++ b/python-package/xgboost/testing/__init__.py @@ -111,8 +111,6 @@ def no_sklearn() -> PytestSkip: def no_dask() -> PytestSkip: - if sys.platform.startswith("win"): - return {"reason": "Unsupported platform.", "condition": True} return no_mod("dask") @@ -193,6 +191,10 @@ def no_multiple(*args: Any) -> PytestSkip: return {"condition": condition, "reason": reason} +def skip_win() -> PytestSkip: + return {"reason": "Unsupported platform.", "condition": is_windows()} + + def skip_s390x() -> PytestSkip: condition = platform.machine() == "s390x" reason = "Known to fail on s390x" @@ -437,7 +439,7 @@ def make_categorical( index = rng.randint( low=0, high=n_samples - 1, size=int(n_samples * sparsity) ) - df.iloc[index, i] = np.NaN + df.iloc[index, i] = np.nan if is_categorical_dtype(df.dtypes[i]): assert n_categories == np.unique(df.dtypes[i].categories).size @@ -968,18 +970,18 @@ def run_worker(rabit_env: Dict[str, Union[str, int]]) -> None: exception_queue.put(e) tracker = RabitTracker(host_ip="127.0.0.1", n_workers=world_size) - tracker.start(world_size) + tracker.start() workers = [] for _ in range(world_size): - worker = threading.Thread(target=run_worker, args=(tracker.worker_envs(),)) + worker = threading.Thread(target=run_worker, args=(tracker.worker_args(),)) workers.append(worker) worker.start() for worker in workers: worker.join() assert exception_queue.empty(), f"Worker failed: {exception_queue.get()}" - tracker.join() + tracker.wait_for() def column_split_feature_names( diff --git a/python-package/xgboost/testing/dask.py b/python-package/xgboost/testing/dask.py index f46803b29941..70e3dc219928 100644 --- a/python-package/xgboost/testing/dask.py +++ b/python-package/xgboost/testing/dask.py @@ -66,7 +66,7 @@ def check_uneven_nan(client: Client, tree_method: str, n_workers: int) -> None: X = pd.DataFrame({"a": range(10000), "b": range(10000, 0, -1)}) y = pd.Series([*[0] * 5000, *[1] * 5000]) - X["a"][:3000:1000] = np.NaN + X["a"][:3000:1000] = np.nan client.wait_for_workers(n_workers=n_workers) diff --git a/python-package/xgboost/testing/ranking.py b/python-package/xgboost/testing/ranking.py index a11eb3e030d1..72cf37aeb4d2 100644 --- a/python-package/xgboost/testing/ranking.py +++ b/python-package/xgboost/testing/ranking.py @@ -100,3 +100,21 @@ def run_ranking_categorical(device: str) -> None: scores = cross_val_score(ltr, X, y) for s in scores: assert s > 0.7 + + +def run_normalization(device: str) -> None: + """Test normalization.""" + X, y, qid, _ = tm.make_ltr(2048, 4, 64, 3) + ltr = xgb.XGBRanker(objective="rank:pairwise", n_estimators=4, device=device) + ltr.fit(X, y, qid=qid, eval_set=[(X, y)], eval_qid=[qid]) + e0 = ltr.evals_result() + + ltr = xgb.XGBRanker( + objective="rank:pairwise", + n_estimators=4, + device=device, + lambdarank_normalization=False, + ) + ltr.fit(X, y, qid=qid, eval_set=[(X, y)], eval_qid=[qid]) + e1 = ltr.evals_result() + assert e1["validation_0"]["ndcg@32"][-1] > e0["validation_0"]["ndcg@32"][-1] diff --git a/python-package/xgboost/tracker.py b/python-package/xgboost/tracker.py index 606c63791c6d..d88b2564054b 100644 --- a/python-package/xgboost/tracker.py +++ b/python-package/xgboost/tracker.py @@ -1,64 +1,12 @@ -# pylint: disable=too-many-instance-attributes, too-many-arguments, too-many-branches -""" -This script is a variant of dmlc-core/dmlc_tracker/tracker.py, -which is a specialized version for xgboost tasks. -""" -import argparse -import logging -import socket -import struct -import sys -from threading import Thread -from typing import Dict, List, Optional, Set, Tuple, Union - -_RingMap = Dict[int, Tuple[int, int]] -_TreeMap = Dict[int, List[int]] - - -class ExSocket: - """ - Extension of socket to handle recv and send of special data - """ - - def __init__(self, sock: socket.socket) -> None: - self.sock = sock - - def recvall(self, nbytes: int) -> bytes: - """Receive number of bytes.""" - res = [] - nread = 0 - while nread < nbytes: - chunk = self.sock.recv(min(nbytes - nread, 1024)) - nread += len(chunk) - res.append(chunk) - return b"".join(res) - - def recvint(self) -> int: - """Receive an integer of 32 bytes""" - return struct.unpack("@i", self.recvall(4))[0] - - def sendint(self, value: int) -> None: - """Send an integer of 32 bytes""" - self.sock.sendall(struct.pack("@i", value)) - - def sendstr(self, value: str) -> None: - """Send a Python string""" - self.sendint(len(value)) - self.sock.sendall(value.encode()) - - def recvstr(self) -> str: - """Receive a Python string""" - slen = self.recvint() - return self.recvall(slen).decode() - - -# magic number used to verify existence of data -MAGIC_NUM = 0xFF99 +"""Tracker for XGBoost collective.""" +import ctypes +import json +import socket +from enum import IntEnum, unique +from typing import Dict, Optional, Union -def get_some_ip(host: str) -> str: - """Get ip from host""" - return socket.getaddrinfo(host, None)[0][4][0] +from .core import _LIB, _check_call, make_jcargs def get_family(addr: str) -> int: @@ -66,439 +14,95 @@ def get_family(addr: str) -> int: return socket.getaddrinfo(addr, None)[0][0] -class WorkerEntry: - """Hanlder to each worker.""" - - def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]): - worker = ExSocket(sock) - self.sock = worker - self.host = get_some_ip(s_addr[0]) - magic = worker.recvint() - assert magic == MAGIC_NUM, f"invalid magic number={magic} from {self.host}" - worker.sendint(MAGIC_NUM) - self.rank = worker.recvint() - self.world_size = worker.recvint() - self.task_id = worker.recvstr() - self.cmd = worker.recvstr() - self.wait_accept = 0 - self.port: Optional[int] = None +class RabitTracker: + """Tracker for the collective used in XGBoost, acting as a coordinator between + workers. - def print(self, use_logger: bool) -> None: - """Execute the print command from worker.""" - msg = self.sock.recvstr() - # On dask we use print to avoid setting global verbosity. - if use_logger: - logging.info(msg.strip()) - else: - print(msg.strip(), flush=True) + Parameters + .......... + sortby: - def decide_rank(self, job_map: Dict[str, int]) -> int: - """Get the rank of current entry.""" - if self.rank >= 0: - return self.rank - if self.task_id != "NULL" and self.task_id in job_map: - return job_map[self.task_id] - return -1 + How to sort the workers for rank assignment. The default is host, but users can + set the `DMLC_TASK_ID` via RABIT initialization arguments and obtain + deterministic rank assignment. Available options are: + - host + - task - def assign_rank( - self, - rank: int, - wait_conn: Dict[int, "WorkerEntry"], - tree_map: _TreeMap, - parent_map: Dict[int, int], - ring_map: _RingMap, - ) -> List[int]: - """Assign the rank for current entry.""" - self.rank = rank - nnset = set(tree_map[rank]) - rprev, next_rank = ring_map[rank] - self.sock.sendint(rank) - # send parent rank - self.sock.sendint(parent_map[rank]) - # send world size - self.sock.sendint(len(tree_map)) - self.sock.sendint(len(nnset)) - # send the rprev and next link - for r in nnset: - self.sock.sendint(r) - # send prev link - if rprev not in (-1, rank): - nnset.add(rprev) - self.sock.sendint(rprev) - else: - self.sock.sendint(-1) - # send next link - if next_rank not in (-1, rank): - nnset.add(next_rank) - self.sock.sendint(next_rank) - else: - self.sock.sendint(-1) + timeout : - return self._get_remote(wait_conn, nnset) + Timeout for constructing the communication group and waiting for the tracker to + shutdown when it's instructed to, doesn't apply to communication when tracking + is running. - def _get_remote( - self, wait_conn: Dict[int, "WorkerEntry"], badset: Set[int] - ) -> List[int]: - while True: - conset = [] - for r in badset: - if r in wait_conn: - conset.append(r) - self.sock.sendint(len(conset)) - self.sock.sendint(len(badset) - len(conset)) - for r in conset: - self.sock.sendstr(wait_conn[r].host) - port = wait_conn[r].port - assert port is not None - # send port of this node to other workers so that they can call connect - self.sock.sendint(port) - self.sock.sendint(r) - nerr = self.sock.recvint() - if nerr != 0: - continue - self.port = self.sock.recvint() - rmset = [] - # all connection was successuly setup - for r in conset: - wait_conn[r].wait_accept -= 1 - if wait_conn[r].wait_accept == 0: - rmset.append(r) - for r in rmset: - wait_conn.pop(r, None) - self.wait_accept = len(badset) - len(conset) - return rmset + The timeout value should take the time of data loading and pre-processing into + account, due to potential lazy execution. + The :py:meth:`.wait_for` method has a different timeout parameter that can stop + the tracker even if the tracker is still being used. A value error is raised + when timeout is reached. -class RabitTracker: - """ - tracker for rabit """ - def __init__( + @unique + class _SortBy(IntEnum): + HOST = 0 + TASK = 1 + + def __init__( # pylint: disable=too-many-arguments self, - host_ip: str, n_workers: int, + host_ip: Optional[str], port: int = 0, - use_logger: bool = False, sortby: str = "host", + timeout: int = 0, ) -> None: - """A Python implementation of RABIT tracker. - - Parameters - .......... - use_logger: - Use logging.info for tracker print command. When set to False, Python print - function is used instead. - - sortby: - How to sort the workers for rank assignment. The default is host, but users - can set the `DMLC_TASK_ID` via RABIT initialization arguments and obtain - deterministic rank assignment. Available options are: - - host - - task - """ - sock = socket.socket(get_family(host_ip), socket.SOCK_STREAM) - sock.bind((host_ip, port)) - self.port = sock.getsockname()[1] - sock.listen(256) - self.sock = sock - self.host_ip = host_ip - self.thread: Optional[Thread] = None - self.n_workers = n_workers - self._use_logger = use_logger - self._sortby = sortby - logging.info("start listen on %s:%d", host_ip, self.port) + handle = ctypes.c_void_p() + if sortby not in ("host", "task"): + raise ValueError("Expecting either 'host' or 'task' for sortby.") + if host_ip is not None: + get_family(host_ip) # use python socket to stop early for invalid address + args = make_jcargs( + host=host_ip, + n_workers=n_workers, + port=port, + dmlc_communicator="rabit", + sortby=self._SortBy.HOST if sortby == "host" else self._SortBy.TASK, + timeout=int(timeout), + ) + _check_call(_LIB.XGTrackerCreate(args, ctypes.byref(handle))) + self.handle = handle + + def free(self) -> None: + """Internal function for testing.""" + if hasattr(self, "handle"): + handle = self.handle + del self.handle + _check_call(_LIB.XGTrackerFree(handle)) def __del__(self) -> None: - if hasattr(self, "sock"): - self.sock.close() - - @staticmethod - def _get_neighbor(rank: int, n_workers: int) -> List[int]: - rank = rank + 1 - ret = [] - if rank > 1: - ret.append(rank // 2 - 1) - if rank * 2 - 1 < n_workers: - ret.append(rank * 2 - 1) - if rank * 2 < n_workers: - ret.append(rank * 2) - return ret - - def worker_envs(self) -> Dict[str, Union[str, int]]: - """ - get environment variables for workers - can be passed in as args or envs - """ - return {"DMLC_TRACKER_URI": self.host_ip, "DMLC_TRACKER_PORT": self.port} + self.free() - def _get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]: - tree_map: _TreeMap = {} - parent_map: Dict[int, int] = {} - for r in range(n_workers): - tree_map[r] = self._get_neighbor(r, n_workers) - parent_map[r] = (r + 1) // 2 - 1 - return tree_map, parent_map + def start(self) -> None: + """Start the tracker. Once started, the client still need to call the + :py:meth:`wait_for` method in order to wait for it to finish (think of it as a + thread). - def find_share_ring( - self, tree_map: _TreeMap, parent_map: Dict[int, int], rank: int - ) -> List[int]: - """ - get a ring structure that tends to share nodes with the tree - return a list starting from rank """ - nset = set(tree_map[rank]) - cset = nset - {parent_map[rank]} - if not cset: - return [rank] - rlst = [rank] - cnt = 0 - for v in cset: - vlst = self.find_share_ring(tree_map, parent_map, v) - cnt += 1 - if cnt == len(cset): - vlst.reverse() - rlst += vlst - return rlst + _check_call(_LIB.XGTrackerRun(self.handle, make_jcargs())) - def get_ring(self, tree_map: _TreeMap, parent_map: Dict[int, int]) -> _RingMap: - """ - get a ring connection used to recover local data - """ - assert parent_map[0] == -1 - rlst = self.find_share_ring(tree_map, parent_map, 0) - assert len(rlst) == len(tree_map) - ring_map: _RingMap = {} - n_workers = len(tree_map) - for r in range(n_workers): - rprev = (r + n_workers - 1) % n_workers - rnext = (r + 1) % n_workers - ring_map[rlst[r]] = (rlst[rprev], rlst[rnext]) - return ring_map + def wait_for(self, timeout: Optional[int] = None) -> None: + """Wait for the tracker to finish all the work and shutdown. When timeout is + reached, a value error is raised. By default we don't have timeout since we + don't know how long it takes for the model to finish training. - def get_link_map(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int], _RingMap]: """ - get the link map, this is a bit hacky, call for better algorithm - to place similar nodes together - """ - tree_map, parent_map = self._get_tree(n_workers) - ring_map = self.get_ring(tree_map, parent_map) - rmap = {0: 0} - k = 0 - for i in range(n_workers - 1): - k = ring_map[k][1] - rmap[k] = i + 1 - - ring_map_: _RingMap = {} - tree_map_: _TreeMap = {} - parent_map_: Dict[int, int] = {} - for k, v in ring_map.items(): - ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]]) - for k, tree_nodes in tree_map.items(): - tree_map_[rmap[k]] = [rmap[x] for x in tree_nodes] - for k, parent in parent_map.items(): - if k != 0: - parent_map_[rmap[k]] = rmap[parent] - else: - parent_map_[rmap[k]] = -1 - return tree_map_, parent_map_, ring_map_ - - def _sort_pending(self, pending: List[WorkerEntry]) -> List[WorkerEntry]: - if self._sortby == "host": - pending.sort(key=lambda s: s.host) - elif self._sortby == "task": - pending.sort(key=lambda s: s.task_id) - return pending - - def accept_workers(self, n_workers: int) -> None: - """Wait for all workers to connect to the tracker.""" - - # set of nodes that finishes the job - shutdown: Dict[int, WorkerEntry] = {} - # set of nodes that is waiting for connections - wait_conn: Dict[int, WorkerEntry] = {} - # maps job id to rank - job_map: Dict[str, int] = {} - # list of workers that is pending to be assigned rank - pending: List[WorkerEntry] = [] - # lazy initialize tree_map - tree_map = None - - while len(shutdown) != n_workers: - fd, s_addr = self.sock.accept() - s = WorkerEntry(fd, s_addr) - if s.cmd == "print": - s.print(self._use_logger) - continue - if s.cmd == "shutdown": - assert s.rank >= 0 and s.rank not in shutdown - assert s.rank not in wait_conn - shutdown[s.rank] = s - logging.debug("Received %s signal from %d", s.cmd, s.rank) - continue - assert s.cmd == "start" - # lazily initialize the workers - if tree_map is None: - assert s.cmd == "start" - if s.world_size > 0: - n_workers = s.world_size - tree_map, parent_map, ring_map = self.get_link_map(n_workers) - # set of nodes that is pending for getting up - todo_nodes = list(range(n_workers)) - else: - assert s.world_size in (-1, n_workers) - if s.cmd == "recover": - assert s.rank >= 0 - - rank = s.decide_rank(job_map) - # batch assignment of ranks - if rank == -1: - assert todo_nodes - pending.append(s) - if len(pending) == len(todo_nodes): - pending = self._sort_pending(pending) - for s in pending: - rank = todo_nodes.pop(0) - if s.task_id != "NULL": - job_map[s.task_id] = rank - s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map) - if s.wait_accept > 0: - wait_conn[rank] = s - logging.debug( - "Received %s signal from %s; assign rank %d", - s.cmd, - s.host, - s.rank, - ) - if not todo_nodes: - logging.info("@tracker All of %d nodes getting started", n_workers) - else: - s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map) - logging.debug("Received %s signal from %d", s.cmd, s.rank) - if s.wait_accept > 0: - wait_conn[rank] = s - logging.info("@tracker All nodes finishes job") - - def start(self, n_workers: int) -> None: - """Strat the tracker, it will wait for `n_workers` to connect.""" - - def run() -> None: - self.accept_workers(n_workers) - - self.thread = Thread(target=run, args=(), daemon=True) - self.thread.start() - - def join(self) -> None: - """Wait for the tracker to finish.""" - while self.thread is not None and self.thread.is_alive(): - self.thread.join(100) - - def alive(self) -> bool: - """Wether the tracker thread is alive""" - return self.thread is not None and self.thread.is_alive() - - -def get_host_ip(host_ip: Optional[str] = None) -> str: - """Get the IP address of current host. If `host_ip` is not none then it will be - returned as it's - - """ - if host_ip is None or host_ip == "auto": - host_ip = "ip" - - if host_ip == "dns": - host_ip = socket.getfqdn() - elif host_ip == "ip": - from socket import gaierror - - try: - host_ip = socket.gethostbyname(socket.getfqdn()) - except gaierror: - logging.debug( - "gethostbyname(socket.getfqdn()) failed... trying on hostname()" - ) - host_ip = socket.gethostbyname(socket.gethostname()) - if host_ip.startswith("127."): - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - # doesn't have to be reachable - s.connect(("10.255.255.255", 1)) - host_ip = s.getsockname()[0] - - assert host_ip is not None - return host_ip - - -def start_rabit_tracker(args: argparse.Namespace) -> None: - """Standalone function to start rabit tracker. - - Parameters - ---------- - args: arguments to start the rabit tracker. - """ - envs = {"DMLC_NUM_WORKER": args.num_workers, "DMLC_NUM_SERVER": args.num_servers} - rabit = RabitTracker( - host_ip=get_host_ip(args.host_ip), n_workers=args.num_workers, use_logger=True - ) - envs.update(rabit.worker_envs()) - rabit.start(args.num_workers) - sys.stdout.write("DMLC_TRACKER_ENV_START\n") - # simply write configuration to stdout - for k, v in envs.items(): - sys.stdout.write(f"{k}={v}\n") - sys.stdout.write("DMLC_TRACKER_ENV_END\n") - sys.stdout.flush() - rabit.join() - - -def main() -> None: - """Main function if tracker is executed in standalone mode.""" - parser = argparse.ArgumentParser(description="Rabit Tracker start.") - parser.add_argument( - "--num-workers", - required=True, - type=int, - help="Number of worker process to be launched.", - ) - parser.add_argument( - "--num-servers", - default=0, - type=int, - help="Number of server process to be launched. Only used in PS jobs.", - ) - parser.add_argument( - "--host-ip", - default=None, - type=str, - help=( - "Host IP addressed, this is only needed " - + "if the host IP cannot be automatically guessed." - ), - ) - parser.add_argument( - "--log-level", - default="INFO", - type=str, - choices=["INFO", "DEBUG"], - help="Logging level of the logger.", - ) - args = parser.parse_args() - - fmt = "%(asctime)s %(levelname)s %(message)s" - if args.log_level == "INFO": - level = logging.INFO - elif args.log_level == "DEBUG": - level = logging.DEBUG - else: - raise RuntimeError(f"Unknown logging level {args.log_level}") - - logging.basicConfig(format=fmt, level=level) - - if args.num_servers == 0: - start_rabit_tracker(args) - else: - raise RuntimeError("Do not yet support start ps tracker in standalone mode.") - - -if __name__ == "__main__": - main() + _check_call(_LIB.XGTrackerWaitFor(self.handle, make_jcargs(timeout=timeout))) + + def worker_args(self) -> Dict[str, Union[str, int]]: + """Get arguments for workers.""" + c_env = ctypes.c_char_p() + _check_call(_LIB.XGTrackerWorkerArgs(self.handle, ctypes.byref(c_env))) + assert c_env.value is not None + env = json.loads(c_env.value) + return env diff --git a/rabit/CMakeLists.txt b/rabit/CMakeLists.txt deleted file mode 100644 index 4562f864f2da..000000000000 --- a/rabit/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -cmake_minimum_required(VERSION 3.18) - -find_package(Threads REQUIRED) - -set(RABIT_SOURCES - ${CMAKE_CURRENT_LIST_DIR}/src/allreduce_base.cc - ${CMAKE_CURRENT_LIST_DIR}/src/rabit_c_api.cc) - -if(RABIT_MOCK) - list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine_mock.cc) -else() - list(APPEND RABIT_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/engine.cc) -endif() - -set(RABIT_SOURCES ${RABIT_SOURCES} PARENT_SCOPE) diff --git a/rabit/LICENSE b/rabit/LICENSE deleted file mode 100644 index 2485f4eaa5a3..000000000000 --- a/rabit/LICENSE +++ /dev/null @@ -1,28 +0,0 @@ -Copyright (c) 2014 by Contributors -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -* Neither the name of rabit nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - diff --git a/rabit/README.md b/rabit/README.md deleted file mode 100644 index 0be1b7015d18..000000000000 --- a/rabit/README.md +++ /dev/null @@ -1 +0,0 @@ -# This directory contains the CPU network module for XGBoost. The library originates from [RABIT](https://github.com/dmlc/rabit) \ No newline at end of file diff --git a/rabit/include/rabit/base.h b/rabit/include/rabit/base.h deleted file mode 100644 index ab3a285d1a0a..000000000000 --- a/rabit/include/rabit/base.h +++ /dev/null @@ -1,19 +0,0 @@ -/*! - * Copyright (c) 2020 by Contributors - * \file base.h - * \brief Macros common to all headers - * - * \author Hyunsu Cho - */ - -#ifndef RABIT_BASE_H_ -#define RABIT_BASE_H_ - -#ifndef _CRT_SECURE_NO_WARNINGS -#define _CRT_SECURE_NO_WARNINGS -#endif // _CRT_SECURE_NO_WARNINGS -#ifndef _CRT_SECURE_NO_DEPRECATE -#define _CRT_SECURE_NO_DEPRECATE -#endif // _CRT_SECURE_NO_DEPRECATE - -#endif // RABIT_BASE_H_ diff --git a/rabit/include/rabit/c_api.h b/rabit/include/rabit/c_api.h deleted file mode 100644 index 6c9e798bea38..000000000000 --- a/rabit/include/rabit/c_api.h +++ /dev/null @@ -1,157 +0,0 @@ -/*! - * Copyright by Contributors - * \file c_api.h - * \author Tianqi Chen - * \brief a C style API of rabit. - */ -#ifndef RABIT_C_API_H_ -#define RABIT_C_API_H_ - -#ifdef __cplusplus -#define RABIT_EXTERN_C extern "C" -#include -#else -#define RABIT_EXTERN_C -#include -#endif // __cplusplus - -#if defined(_MSC_VER) || defined(_WIN32) -#define RABIT_DLL RABIT_EXTERN_C __declspec(dllexport) -#else -#define RABIT_DLL RABIT_EXTERN_C __attribute__ ((visibility ("default"))) -#endif // defined(_MSC_VER) || defined(_WIN32) - -/*! \brief rabit unsigned long type */ -typedef unsigned long rbt_ulong; // NOLINT(*) - -/*! - * \brief initialize the rabit module, - * call this once before using anything - * The additional arguments is not necessary. - * Usually rabit will detect settings - * from environment variables. - * \param argc number of arguments in argv - * \param argv the array of input arguments - * \return true if rabit is initialized successfully otherwise false - */ -RABIT_DLL bool RabitInit(int argc, char *argv[]); - -/*! - * \brief finalize the rabit engine, - * call this function after you finished all jobs. - * \return true if rabit is initialized successfully otherwise false - */ -RABIT_DLL int RabitFinalize(void); - -/*! - * \brief get rank of previous process in ring topology - * \return rank number of worker - * */ -RABIT_DLL int RabitGetRingPrevRank(void); - -/*! - * \brief get rank of current process - * \return rank number of worker - * */ -RABIT_DLL int RabitGetRank(void); - -/*! - * \brief get total number of process - * \return total world size - * */ -RABIT_DLL int RabitGetWorldSize(void); - -/*! - * \brief get rank of current process - * \return if rabit is distributed - * */ -RABIT_DLL int RabitIsDistributed(void); - -/*! - * \brief print the msg to the tracker, - * this function can be used to communicate the information of the progress to - * the user who monitors the tracker - * \param msg the message to be printed - */ -RABIT_DLL int RabitTrackerPrint(const char *msg); -/*! - * \brief get name of processor - * \param out_name hold output string - * \param out_len hold length of output string - * \param max_len maximum buffer length of input - */ -RABIT_DLL void RabitGetProcessorName(char *out_name, - rbt_ulong *out_len, - rbt_ulong max_len); -/*! - * \brief broadcast an memory region to all others from root - * - * Example: int a = 1; Broadcast(&a, sizeof(a), root); - * \param sendrecv_data the pointer to send or receive buffer, - * \param size the size of the data - * \param root the root of process - */ -RABIT_DLL int RabitBroadcast(void *sendrecv_data, rbt_ulong size, int root); - -/*! - * \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf, - * the data provided by current node k is [slice_begin, slice_end), - * the next node's segment must start with slice_end - * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments - * use a ring based algorithm - * - * \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually - * \param total_size total size of data to be gathered - * \param beginIndex beginning of the current slice in sendrecvbuf of type enum_dtype - * \param size_node_slice size of the current node slice - * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size - * \param enum_dtype the enumeration of data type, see rabit::engine::mpi::DataType in engine.h of rabit include - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType - */ -RABIT_DLL int RabitAllgather(void *sendrecvbuf, size_t total_size, - size_t beginIndex, size_t size_node_slice, - size_t size_prev_slice, int enum_dtype); - -/*! - * \brief perform in-place allreduce, on sendrecvbuf - * this function is NOT thread-safe - * - * Example Usage: the following code gives sum of the result - * vector data(10); - * ... - * Allreduce(&data[0], data.size()); - * ... - * \param sendrecvbuf buffer for both sending and receiving data - * \param count number of elements to be reduced - * \param enum_dtype the enumeration of data type, see rabit::engine::mpi::DataType in engine.h of rabit include - * \param enum_op the enumeration of operation type, see rabit::engine::mpi::OpType in engine.h of rabit - * \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg) - * will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_. - * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called - * \param prepare_arg argument used to passed into the lazy preprocessing function - */ -RABIT_DLL int RabitAllreduce(void *sendrecvbuf, size_t count, int enum_dtype, - int enum_op, void (*prepare_fun)(void *arg), - void *prepare_arg); - -/*! - * \return version number of current stored model, - * which means how many calls to CheckPoint we made so far - * \return rabit version number - */ -RABIT_DLL int RabitVersionNumber(void); - - -/*! - * \brief a Dummy function, - * used to cause force link of C API into the DLL. - * \code - * \/\/force link rabit C API library. - * static int must_link_rabit_ = RabitLinkTag(); - * \endcode - * \return a dummy integer. - */ -RABIT_DLL int RabitLinkTag(void); - -#endif // RABIT_C_API_H_ diff --git a/rabit/include/rabit/internal/engine.h b/rabit/include/rabit/internal/engine.h deleted file mode 100644 index aa074fb39b97..000000000000 --- a/rabit/include/rabit/internal/engine.h +++ /dev/null @@ -1,197 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file engine.h - * \brief This file defines the core interface of rabit library - * \author Tianqi Chen, Nacho, Tianyi - */ -#ifndef RABIT_INTERNAL_ENGINE_H_ -#define RABIT_INTERNAL_ENGINE_H_ -#include -#include "rabit/serializable.h" - -namespace MPI { // NOLINT -/*! \brief MPI data type just to be compatible with MPI reduce function*/ -class Datatype; -} - -/*! \brief namespace of rabit */ -namespace rabit { -/*! \brief core interface of the engine */ -namespace engine { -/*! \brief interface of core Allreduce engine */ -class IEngine { - public: - /*! - * \brief Preprocessing function, that is called before AllReduce, - * used to prepare the data used by AllReduce - * \param arg additional possible argument used to invoke the preprocessor - */ - typedef void (PreprocFunction) (void *arg); // NOLINT - /*! - * \brief reduce function, the same form of MPI reduce function is used, - * to be compatible with MPI interface - * In all the functions, the memory is ensured to aligned to 64-bit - * which means it is OK to cast src,dst to double* int* etc - * \param src pointer to source space - * \param dst pointer to destination reduction - * \param count total number of elements to be reduced (note this is total number of elements instead of bytes) - * the definition of the reduce function should be type aware - * \param dtype the data type object, to be compatible with MPI reduce - */ - typedef void (ReduceFunction) (const void *src, // NOLINT - void *dst, int count, - const MPI::Datatype &dtype); - /*! \brief virtual destructor */ - ~IEngine() = default; - /*! - * \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf, - * the data provided by current node k is [slice_begin, slice_end), - * the next node's segment must start with slice_end - * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments - * use a ring based algorithm - * - * \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually - * \param total_size total size of data to be gathered - * \param slice_begin beginning of the current slice - * \param slice_end end of the current slice - * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size - */ - virtual void Allgather(void *sendrecvbuf, - size_t total_size, - size_t slice_begin, - size_t slice_end, - size_t size_prev_slice) = 0; - /*! - * \brief performs in-place Allreduce, on sendrecvbuf - * this function is NOT thread-safe - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param type_nbytes the number of bytes the type has - * \param count number of elements to be reduced - * \param reducer reduce function - * \param prepare_func Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg) - * will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf. - * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called - * \param prepare_arg argument used to pass into the lazy preprocessing function - */ - virtual void Allreduce(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer, - PreprocFunction prepare_fun = nullptr, - void *prepare_arg = nullptr) = 0; - /*! - * \brief broadcasts data from root to every other node - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param size the size of the data to be broadcasted - * \param root the root worker id to broadcast the data - */ - virtual void Broadcast(void *sendrecvbuf_, size_t size, int root) = 0; - /*! - * deprecated - */ - virtual int LoadCheckPoint() = 0; - /*! - * \brief Increase internal version number. Deprecated. - */ - virtual void CheckPoint() = 0; - /*! - * \return version number of the current stored model, - * which means how many calls to CheckPoint we made so far - * \sa LoadCheckPoint, CheckPoint - */ - virtual int VersionNumber() const = 0; - /*! \brief gets rank of previous node in ring topology */ - virtual int GetRingPrevRank() const = 0; - /*! \brief gets rank of current node */ - virtual int GetRank() const = 0; - /*! \brief gets total number of nodes */ - virtual int GetWorldSize() const = 0; - /*! \brief whether we run in distribted mode */ - virtual bool IsDistributed() const = 0; - /*! \brief gets the host name of the current node */ - virtual std::string GetHost() const = 0; - /*! - * \brief prints the msg in the tracker, - * this function can be used to communicate progress information to - * the user who monitors the tracker - * \param msg message to be printed in the tracker - */ - virtual void TrackerPrint(const std::string &msg) = 0; -}; - -/*! \brief initializes the engine module */ -bool Init(int argc, char *argv[]); -/*! \brief finalizes the engine module */ -bool Finalize(); -/*! \brief singleton method to get engine */ -IEngine *GetEngine(); - -/*! \brief namespace that contains stubs to be compatible with MPI */ -namespace mpi { -/*!\brief enum of all operators */ -enum OpType { - kMax = 0, - kMin = 1, - kSum = 2, - kBitwiseAND = 3, - kBitwiseOR = 4, - kBitwiseXOR = 5, -}; -/*!\brief enum of supported data types */ -enum DataType { - kChar = 0, - kUChar = 1, - kInt = 2, - kUInt = 3, - kLong = 4, - kULong = 5, - kFloat = 6, - kDouble = 7, - kLongLong = 8, - kULongLong = 9 -}; -} // namespace mpi -/*! - * \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf, - * the data provided by current node k is [slice_begin, slice_end), - * the next node's segment must start with slice_end - * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments - * use a ring based algorithm - * - * \param sendrecvbuf buffer for both sending and receiving data, it is a ring conceptually - * \param total_size total size of data to be gathered - * \param slice_begin beginning of the current slice - * \param slice_end end of the current slice - * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size - */ -void Allgather(void* sendrecvbuf, - size_t total_size, - size_t slice_begin, - size_t slice_end, - size_t size_prev_slice); -/*! - * \brief perform in-place Allreduce, on sendrecvbuf - * this is an internal function used by rabit to be able to compile with MPI - * do not use this function directly - * \param sendrecvbuf buffer for both sending and receiving data - * \param type_nbytes the number of bytes the type has - * \param count number of elements to be reduced - * \param reducer reduce function - * \param dtype the data type - * \param op the reduce operator type - * \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg) - * will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_. - * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called - * \param prepare_arg argument used to pass into the lazy preprocessing function. - */ -void Allreduce_(void *sendrecvbuf, // NOLINT - size_t type_nbytes, - size_t count, - IEngine::ReduceFunction red, - mpi::DataType dtype, - mpi::OpType op, - IEngine::PreprocFunction prepare_fun = nullptr, - void *prepare_arg = nullptr); -} // namespace engine -} // namespace rabit -#endif // RABIT_INTERNAL_ENGINE_H_ diff --git a/rabit/include/rabit/internal/io.h b/rabit/include/rabit/internal/io.h deleted file mode 100644 index d5d0fee4d79c..000000000000 --- a/rabit/include/rabit/internal/io.h +++ /dev/null @@ -1,118 +0,0 @@ -/** - * Copyright 2014-2023, XGBoost Contributors - * \file io.h - * \brief utilities with different serializable implementations - * \author Tianqi Chen - */ -#ifndef RABIT_INTERNAL_IO_H_ -#define RABIT_INTERNAL_IO_H_ - -#include -#include // for size_t -#include -#include // for memcpy -#include -#include -#include -#include - -#include "dmlc/io.h" -#include "xgboost/logging.h" - -namespace rabit::utils { -/*! \brief re-use definition of dmlc::SeekStream */ -using SeekStream = dmlc::SeekStream; -/** - * @brief Fixed size memory buffer as a stream. - */ -struct MemoryFixSizeBuffer : public SeekStream { - public: - // similar to SEEK_END in libc - static std::size_t constexpr kSeekEnd = std::numeric_limits::max(); - - public: - /** - * @brief Ctor - * - * @param p_buffer Pointer to the source buffer with size `buffer_size`. - * @param buffer_size Size of the source buffer - */ - MemoryFixSizeBuffer(void *p_buffer, std::size_t buffer_size) - : p_buffer_(reinterpret_cast(p_buffer)), buffer_size_(buffer_size) {} - ~MemoryFixSizeBuffer() override = default; - - std::size_t Read(void *ptr, std::size_t size) override { - std::size_t nread = std::min(buffer_size_ - curr_ptr_, size); - if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread); - curr_ptr_ += nread; - return nread; - } - void Write(const void *ptr, std::size_t size) override { - if (size == 0) return; - CHECK_LE(curr_ptr_ + size, buffer_size_); - std::memcpy(p_buffer_ + curr_ptr_, ptr, size); - curr_ptr_ += size; - } - void Seek(std::size_t pos) override { - if (pos == kSeekEnd) { - curr_ptr_ = buffer_size_; - } else { - curr_ptr_ = static_cast(pos); - } - } - /** - * @brief Current position in the buffer (stream). - */ - std::size_t Tell() override { return curr_ptr_; } - [[nodiscard]] virtual bool AtEnd() const { return curr_ptr_ == buffer_size_; } - - protected: - /*! \brief in memory buffer */ - char *p_buffer_{nullptr}; - /*! \brief current pointer */ - std::size_t buffer_size_{0}; - /*! \brief current pointer */ - std::size_t curr_ptr_{0}; -}; - -/*! \brief a in memory buffer that can be read and write as stream interface */ -struct MemoryBufferStream : public SeekStream { - public: - explicit MemoryBufferStream(std::string *p_buffer) - : p_buffer_(p_buffer) { - curr_ptr_ = 0; - } - ~MemoryBufferStream() override = default; - size_t Read(void *ptr, size_t size) override { - CHECK_LE(curr_ptr_, p_buffer_->length()) << "read can not have position excceed buffer length"; - size_t nread = std::min(p_buffer_->length() - curr_ptr_, size); - if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread); - curr_ptr_ += nread; - return nread; - } - void Write(const void *ptr, size_t size) override { - if (size == 0) return; - if (curr_ptr_ + size > p_buffer_->length()) { - p_buffer_->resize(curr_ptr_+size); - } - std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size); - curr_ptr_ += size; - } - void Seek(size_t pos) override { - curr_ptr_ = static_cast(pos); - } - size_t Tell() override { - return curr_ptr_; - } - virtual bool AtEnd() const { - return curr_ptr_ == p_buffer_->length(); - } - - private: - /*! \brief in memory buffer */ - std::string *p_buffer_; - /*! \brief current pointer */ - size_t curr_ptr_; -}; // class MemoryBufferStream -} // namespace rabit::utils -#endif // RABIT_INTERNAL_IO_H_ diff --git a/rabit/include/rabit/internal/rabit-inl.h b/rabit/include/rabit/internal/rabit-inl.h deleted file mode 100644 index 49b086320d25..000000000000 --- a/rabit/include/rabit/internal/rabit-inl.h +++ /dev/null @@ -1,234 +0,0 @@ -/*! - * Copyright (c) 2014-2019 by Contributors - * \file rabit-inl.h - * \brief implementation of inline template function for rabit interface - * - * \author Tianqi Chen - */ -#ifndef RABIT_INTERNAL_RABIT_INL_H_ -#define RABIT_INTERNAL_RABIT_INL_H_ -// use engine for implementation -#include -#include -#include "rabit/internal/io.h" -#include "rabit/internal/utils.h" -#include "rabit/rabit.h" - -namespace rabit { -namespace engine { -namespace mpi { -// template function to translate type to enum indicator -template -inline DataType GetType(); -template<> -inline DataType GetType() { - return kChar; -} -template<> -inline DataType GetType() { - return kUChar; -} -template<> -inline DataType GetType() { - return kInt; -} -template<> -inline DataType GetType() { // NOLINT(*) - return kUInt; -} -template<> -inline DataType GetType() { // NOLINT(*) - return kLong; -} -template<> -inline DataType GetType() { // NOLINT(*) - return kULong; -} -template<> -inline DataType GetType() { - return kFloat; -} -template<> -inline DataType GetType() { - return kDouble; -} -template<> -inline DataType GetType() { // NOLINT(*) - return kLongLong; -} -template<> -inline DataType GetType() { // NOLINT(*) - return kULongLong; -} -} // namespace mpi -} // namespace engine - -namespace op { -struct Max { - static const engine::mpi::OpType kType = engine::mpi::kMax; - template - inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*) - if (dst < src) dst = src; - } -}; -struct Min { - static const engine::mpi::OpType kType = engine::mpi::kMin; - template - inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*) - if (dst > src) dst = src; - } -}; -struct Sum { - static const engine::mpi::OpType kType = engine::mpi::kSum; - template - inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*) - dst += src; - } -}; -struct BitAND { - static const engine::mpi::OpType kType = engine::mpi::kBitwiseAND; - template - inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*) - dst &= src; - } -}; -struct BitOR { - static const engine::mpi::OpType kType = engine::mpi::kBitwiseOR; - template - inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*) - dst |= src; - } -}; -struct BitXOR { - static const engine::mpi::OpType kType = engine::mpi::kBitwiseXOR; - template - inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*) - dst ^= src; - } -}; -template -inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &) { - const DType *src = static_cast(src_); - DType *dst = (DType *)dst_; // NOLINT(*) - for (int i = 0; i < len; i++) { - OP::Reduce(dst[i], src[i]); - } -} -} // namespace op - -// initialize the rabit engine -inline bool Init(int argc, char *argv[]) { - return engine::Init(argc, argv); -} -// finalize the rabit engine -inline bool Finalize() { - return engine::Finalize(); -} -// get the rank of the previous worker in ring topology -inline int GetRingPrevRank() { - return engine::GetEngine()->GetRingPrevRank(); -} -// get the rank of current process -inline int GetRank() { - return engine::GetEngine()->GetRank(); -} -// the the size of the world -inline int GetWorldSize() { - return engine::GetEngine()->GetWorldSize(); -} -// whether rabit is distributed -inline bool IsDistributed() { - return engine::GetEngine()->IsDistributed(); -} -// get the name of current processor -inline std::string GetProcessorName() { - return engine::GetEngine()->GetHost(); -} -// broadcast data to all other nodes from root -inline void Broadcast(void *sendrecv_data, size_t size, int root) { - engine::GetEngine()->Broadcast(sendrecv_data, size, root); -} -template -inline void Broadcast(std::vector *sendrecv_data, int root) { - size_t size = sendrecv_data->size(); - Broadcast(&size, sizeof(size), root); - if (sendrecv_data->size() != size) { - sendrecv_data->resize(size); - } - if (size != 0) { - Broadcast(&(*sendrecv_data)[0], size * sizeof(DType), root); - } -} -inline void Broadcast(std::string *sendrecv_data, int root) { - size_t size = sendrecv_data->length(); - Broadcast(&size, sizeof(size), root); - if (sendrecv_data->length() != size) { - sendrecv_data->resize(size); - } - if (size != 0) { - Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root); - } -} - -// perform inplace Allreduce -template -inline void Allreduce(DType *sendrecvbuf, size_t count, - void (*prepare_fun)(void *arg), - void *prepare_arg) { - engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer, - engine::mpi::GetType(), OP::kType, prepare_fun, prepare_arg); -} - -// C++11 support for lambda prepare function -#if DMLC_USE_CXX11 -inline void InvokeLambda(void *fun) { - (*static_cast*>(fun))(); -} -template -inline void Allreduce(DType *sendrecvbuf, size_t count, - std::function prepare_fun) { - engine::Allreduce_(sendrecvbuf, sizeof(DType), count, op::Reducer, - engine::mpi::GetType(), OP::kType, InvokeLambda, &prepare_fun); -} - -// Performs inplace Allgather -template -inline void Allgather(DType *sendrecvbuf, - size_t totalSize, - size_t beginIndex, - size_t sizeNodeSlice, - size_t sizePrevSlice) { - engine::GetEngine()->Allgather(sendrecvbuf, totalSize * sizeof(DType), beginIndex * sizeof(DType), - (beginIndex + sizeNodeSlice) * sizeof(DType), - sizePrevSlice * sizeof(DType)); -} -#endif // C++11 - -// print message to the tracker -inline void TrackerPrint(const std::string &msg) { - engine::GetEngine()->TrackerPrint(msg); -} -#ifndef RABIT_STRICT_CXX98_ -inline void TrackerPrintf(const char *fmt, ...) { - const int kPrintBuffer = 1 << 10; - std::string msg(kPrintBuffer, '\0'); - va_list args; - va_start(args, fmt); - vsnprintf(&msg[0], kPrintBuffer, fmt, args); - va_end(args); - msg.resize(strlen(msg.c_str())); - TrackerPrint(msg); -} - -#endif // RABIT_STRICT_CXX98_ - -// deprecated, planned for removal after checkpoing from JVM package is removed. -inline int LoadCheckPoint() { return engine::GetEngine()->LoadCheckPoint(); } -// deprecated, increase internal version number -inline void CheckPoint() { engine::GetEngine()->CheckPoint(); } -// return the version number of currently stored model -inline int VersionNumber() { - return engine::GetEngine()->VersionNumber(); -} -} // namespace rabit -#endif // RABIT_INTERNAL_RABIT_INL_H_ diff --git a/rabit/include/rabit/internal/socket.h b/rabit/include/rabit/internal/socket.h index 89e3244822df..3701146d4577 100644 --- a/rabit/include/rabit/internal/socket.h +++ b/rabit/include/rabit/internal/socket.h @@ -1,5 +1,5 @@ /** - * Copyright 2014-2023, XGBoost Contributors + * Copyright 2014-2024, XGBoost Contributors * \file socket.h * \author Tianqi Chen */ @@ -95,11 +95,32 @@ int PollImpl(PollFD* pfd, int nfds, std::chrono::seconds timeout) noexcept(true) template std::enable_if_t, xgboost::collective::Result> PollError(E const& revents) { if ((revents & POLLERR) != 0) { - return xgboost::system::FailWithCode("Poll error condition."); + auto err = errno; + auto str = strerror(err); + return xgboost::system::FailWithCode(std::string{"Poll error condition:"} + std::string{str} + + " code:" + std::to_string(err)); } if ((revents & POLLNVAL) != 0) { return xgboost::system::FailWithCode("Invalid polling request."); } + if ((revents & POLLHUP) != 0) { + // Excerpt from the Linux manual: + // + // Note that when reading from a channel such as a pipe or a stream socket, this event + // merely indicates that the peer closed its end of the channel.Subsequent reads from + // the channel will return 0 (end of file) only after all outstanding data in the + // channel has been consumed. + // + // We don't usually have a barrier for exiting workers, it's normal to have one end + // exit while the other still reading data. + return xgboost::collective::Success(); + } +#if defined(POLLRDHUP) + // Linux only flag + if ((revents & POLLRDHUP) != 0) { + return xgboost::system::FailWithCode("Poll hung up on the other end."); + } +#endif // defined(POLLRDHUP) return xgboost::collective::Success(); } @@ -179,9 +200,11 @@ struct PollHelper { } std::int32_t ret = PollImpl(fdset.data(), fdset.size(), timeout); if (ret == 0) { - return xgboost::collective::Fail("Poll timeout.", std::make_error_code(std::errc::timed_out)); + return xgboost::collective::Fail( + "Poll timeout:" + std::to_string(timeout.count()) + " seconds.", + std::make_error_code(std::errc::timed_out)); } else if (ret < 0) { - return xgboost::system::FailWithCode("Poll failed."); + return xgboost::system::FailWithCode("Poll failed, nfds:" + std::to_string(fdset.size())); } for (auto& pfd : fdset) { @@ -191,12 +214,7 @@ struct PollHelper { } auto revents = pfd.revents & pfd.events; - if (!revents) { - // FIXME(jiamingy): remove this once rabit is replaced. - fds.erase(pfd.fd); - } else { - fds[pfd.fd].events = revents; - } + fds[pfd.fd].events = revents; } return xgboost::collective::Success(); } diff --git a/rabit/include/rabit/internal/utils.h b/rabit/include/rabit/internal/utils.h deleted file mode 100644 index c1739ce7967b..000000000000 --- a/rabit/include/rabit/internal/utils.h +++ /dev/null @@ -1,146 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file utils.h - * \brief simple utils to support the code - * \author Tianqi Chen - */ -#ifndef RABIT_INTERNAL_UTILS_H_ -#define RABIT_INTERNAL_UTILS_H_ - -#include - -#include -#include -#include -#include -#include -#include -#include - -#include "dmlc/io.h" -#include "xgboost/logging.h" - -#if !defined(__GNUC__) || defined(__FreeBSD__) -#define fopen64 std::fopen -#endif // !defined(__GNUC__) || defined(__FreeBSD__) - -#ifndef _MSC_VER - -#ifdef _FILE_OFFSET_BITS -#if _FILE_OFFSET_BITS == 32 -#pragma message("Warning: FILE OFFSET BITS defined to be 32 bit") -#endif // _FILE_OFFSET_BITS == 32 -#endif // _FILE_OFFSET_BITS - -#ifdef __APPLE__ -#define off64_t off_t -#define fopen64 std::fopen -#endif // __APPLE__ - -extern "C" { -#include -} -#endif // _MSC_VER - -#include - -namespace rabit { -/*! \brief namespace for helper utils of the project */ -namespace utils { - -/*! \brief error message buffer length */ -const int kPrintBuffer = 1 << 12; - -/* \brief Case-insensitive string comparison */ -inline int CompareStringsCaseInsensitive(const char* s1, const char* s2) { -#ifdef _MSC_VER - return _stricmp(s1, s2); -#else // _MSC_VER - return strcasecmp(s1, s2); -#endif // _MSC_VER -} - -/* \brief parse config string too bool*/ -inline bool StringToBool(const char* s) { - return CompareStringsCaseInsensitive(s, "true") == 0 || atoi(s) != 0; -} - -/*! \brief printf, prints messages to the console */ -inline void Printf(const char *fmt, ...) { - std::string msg(kPrintBuffer, '\0'); - va_list args; - va_start(args, fmt); - vsnprintf(&msg[0], kPrintBuffer, fmt, args); - va_end(args); - LOG(CONSOLE) << msg; -} - -/*! \brief assert a condition is true, use this to handle debug information */ -inline void Assert(bool exp, const char *fmt, ...) { - if (!exp) { - std::string msg(kPrintBuffer, '\0'); - va_list args; - va_start(args, fmt); - vsnprintf(&msg[0], kPrintBuffer, fmt, args); - va_end(args); - LOG(FATAL) << msg; - } -} - -/*!\brief same as assert, but this is intended to be used as a message for users */ -inline void Check(bool exp, const char *fmt, ...) { - if (!exp) { - std::string msg(kPrintBuffer, '\0'); - va_list args; - va_start(args, fmt); - vsnprintf(&msg[0], kPrintBuffer, fmt, args); - va_end(args); - LOG(FATAL) << msg; - } -} - -/*! \brief report error message, same as check */ -inline void Error(const char *fmt, ...) { - { - std::string msg(kPrintBuffer, '\0'); - va_list args; - va_start(args, fmt); - vsnprintf(&msg[0], kPrintBuffer, fmt, args); - va_end(args); - LOG(FATAL) << msg; - } -} -} // namespace utils - -// Can not use std::min on Windows with msvc due to: -// error C2589: '(': illegal token on right side of '::' -template -auto Min(T const& l, T const& r) { - return l < r ? l : r; -} -// same with Min -template -auto Max(T const& l, T const& r) { - return l > r ? l : r; -} - -// easy utils that can be directly accessed in xgboost -/*! \brief get the beginning address of a vector */ -template -inline T *BeginPtr(std::vector &vec) { // NOLINT(*) - if (vec.size() == 0) { - return nullptr; - } else { - return &vec[0]; - } -} -inline char* BeginPtr(std::string &str) { // NOLINT(*) - if (str.length() == 0) return nullptr; - return &str[0]; -} -inline const char* BeginPtr(const std::string &str) { - if (str.length() == 0) return nullptr; - return &str[0]; -} -} // namespace rabit -#endif // RABIT_INTERNAL_UTILS_H_ diff --git a/rabit/include/rabit/rabit.h b/rabit/include/rabit/rabit.h deleted file mode 100644 index 10ea9a47f858..000000000000 --- a/rabit/include/rabit/rabit.h +++ /dev/null @@ -1,237 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file rabit.h - * \brief This file defines rabit's Allreduce/Broadcast interface - * The rabit engine contains the actual implementation - * Code that only uses this header can also be compiled with MPI Allreduce (non fault-tolerant), - * - * rabit.h and serializable.h is all what the user needs to use the rabit interface - * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou - */ -#ifndef RABIT_RABIT_H_ // NOLINT(*) -#define RABIT_RABIT_H_ // NOLINT(*) -#include -#include -#include -// engine definition of rabit, defines internal implementation -// to use rabit interface, there is no need to read engine.h -// rabit.h and serializable.h are enough to use the interface -#include "./internal/engine.h" - -/*! \brief rabit namespace */ -namespace rabit { -/*! - * \brief defines stream used in rabit - * see definition of Stream in dmlc/io.h - */ -using Stream = dmlc::Stream; -/*! - * \brief defines serializable objects used in rabit - * see definition of Serializable in dmlc/io.h - */ -using Serializable = dmlc::Serializable; - -/*! - * \brief reduction operators namespace - */ -namespace op { -/*! - * \class rabit::op::Max - * \brief maximum reduction operator - */ -struct Max; -/*! - * \class rabit::op::Min - * \brief minimum reduction operator - */ -struct Min; -/*! - * \class rabit::op::Sum - * \brief sum reduction operator - */ -struct Sum; -/*! - * \class rabit::op::BitAND - * \brief bitwise AND reduction operator - */ -struct BitAND; -/*! - * \class rabit::op::BitOR - * \brief bitwise OR reduction operator - */ -struct BitOR; -/*! - * \class rabit::op::BitXOR - * \brief bitwise XOR reduction operator - */ -struct BitXOR; -} // namespace op -/*! - * \brief initializes rabit, call this once at the beginning of your program - * \param argc number of arguments in argv - * \param argv the array of input arguments - * \return true if initialized successfully, otherwise false - */ -inline bool Init(int argc, char *argv[]); -/*! - * \brief finalizes the rabit engine, call this function after you finished with all the jobs - * \return true if finalized successfully, otherwise false - */ -inline bool Finalize(); -/*! \brief gets rank of the current process - * \return rank number of worker*/ -inline int GetRank(); -/*! \brief gets total number of processes - * \return total world size*/ -inline int GetWorldSize(); -/*! \brief whether rabit env is in distributed mode - * \return is distributed*/ -inline bool IsDistributed(); - -/*! \brief gets processor's name - * \return processor name*/ -inline std::string GetProcessorName(); -/*! - * \brief prints the msg to the tracker, - * this function can be used to communicate progress information to - * the user who monitors the tracker - * \param msg the message to be printed - */ -inline void TrackerPrint(const std::string &msg); - -#ifndef RABIT_STRICT_CXX98_ -/*! - * \brief prints the msg to the tracker, this function may not be available - * in very strict c++98 compilers, though it usually is. - * this function can be used to communicate progress information to - * the user who monitors the tracker - * \param fmt the format string - */ -inline void TrackerPrintf(const char *fmt, ...); -#endif // RABIT_STRICT_CXX98_ -/*! - * \brief broadcasts a memory region to every node from the root - * - * Example: int a = 1; Broadcast(&a, sizeof(a), root); - * \param sendrecv_data the pointer to the send/receive buffer, - * \param size the data size - * \param root the process root - */ -inline void Broadcast(void *sendrecv_data, size_t size, int root); - -/*! - * \brief broadcasts an std::vector to every node from root - * \param sendrecv_data the pointer to send/receive vector, - * for the receiver, the vector does not need to be pre-allocated - * \param root the process root - * \tparam DType the data type stored in the vector, has to be a simple data type - * that can be directly transmitted by sending the sizeof(DType) - */ -template -inline void Broadcast(std::vector *sendrecv_data, int root); -/*! - * \brief broadcasts a std::string to every node from the root - * \param sendrecv_data the pointer to the send/receive buffer, - * for the receiver, the vector does not need to be pre-allocated - * \param _file caller file name used to generate unique cache key - * \param _line caller line number used to generate unique cache key - * \param _caller caller function name used to generate unique cache key - * \param root the process root - */ -inline void Broadcast(std::string *sendrecv_data, int root); -/*! - * \brief performs in-place Allreduce on sendrecvbuf - * this function is NOT thread-safe - * - * Example Usage: the following code does an Allreduce and outputs the sum as the result - * \code{.cpp} - * vector data(10); - * ... - * Allreduce(&data[0], data.size()); - * ... - * \endcode - * - * \param sendrecvbuf buffer for both sending and receiving data - * \param count number of elements to be reduced - * \param prepare_fun Lazy preprocessing function, if it is not NULL, prepare_fun(prepare_arg) - * will be called by the function before performing Allreduce in order to initialize the data in sendrecvbuf. - * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called - * \param prepare_arg argument used to pass into the lazy preprocessing function - * \tparam OP see namespace op, reduce operator - * \tparam DType data type - */ -template -inline void Allreduce(DType *sendrecvbuf, size_t count, - void (*prepare_fun)(void *) = nullptr, - void *prepare_arg = nullptr); - -/*! -* \brief Allgather function, each node have a segment of data in the ring of sendrecvbuf, -* the data provided by current node k is [slice_begin, slice_end), -* the next node's segment must start with slice_end -* after the call of Allgather, sendrecvbuf_ contains all the contents including all segments -* use a ring based algorithm -* -* \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually -* \param total_size total size of data to be gathered -* \param slice_begin beginning of the current slice -* \param slice_end end of the current slice -* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size -*/ -template -inline void Allgather(DType *sendrecvbuf_, - size_t total_size, - size_t slice_begin, - size_t slice_end, - size_t size_prev_slice); - -// C++11 support for lambda prepare function -#if DMLC_USE_CXX11 -/*! - * \brief performs in-place Allreduce, on sendrecvbuf - * with a prepare function specified by a lambda function - * - * Example Usage: - * \code{.cpp} - * // the following code does an Allreduce and outputs the sum as the result - * vector data(10); - * ... - * Allreduce(&data[0], data.size(), [&]() { - * for (int i = 0; i < 10; ++i) { - * data[i] = i; - * } - * }); - * ... - * \endcode - * \param sendrecvbuf buffer for both sending and receiving data - * \param count number of elements to be reduced - * \param prepare_fun Lazy lambda preprocessing function, prepare_fun() will be invoked - * by the function before performing Allreduce in order to initialize the data in sendrecvbuf. - * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called - * \tparam OP see namespace op, reduce operator - * \tparam DType data type - */ -template -inline void Allreduce(DType *sendrecvbuf, size_t count, - std::function prepare_fun); -#endif // C++11 - -/*! - * \brief deprecated, planned for removal after checkpoing from JVM package is removed. - */ -inline int LoadCheckPoint(); -/*! - * \brief deprecated, planned for removal after checkpoing from JVM package is removed. - */ -inline void CheckPoint(); - -/*! - * \return version number of the current stored model, - * which means how many calls to CheckPoint we made so far - * \sa LoadCheckPoint, CheckPoint - */ -inline int VersionNumber(); -} // namespace rabit -// implementation of template functions -#include "./internal/rabit-inl.h" -#endif // RABIT_RABIT_H_ // NOLINT(*) diff --git a/rabit/include/rabit/serializable.h b/rabit/include/rabit/serializable.h deleted file mode 100644 index 77508292a986..000000000000 --- a/rabit/include/rabit/serializable.h +++ /dev/null @@ -1,26 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file serializable.h - * \brief defines serializable interface of rabit - * \author Tianqi Chen - */ -#ifndef RABIT_SERIALIZABLE_H_ -#define RABIT_SERIALIZABLE_H_ -#include -#include -#include "rabit/internal/utils.h" - -namespace rabit { -/*! - * \brief defines stream used in rabit - * see definition of Stream in dmlc/io.h - */ -using Stream = dmlc::Stream ; -/*! - * \brief defines serializable objects used in rabit - * see definition of Serializable in dmlc/io.h - */ -using Serializable = dmlc::Serializable; - -} // namespace rabit -#endif // RABIT_SERIALIZABLE_H_ diff --git a/rabit/src/allreduce_base.cc b/rabit/src/allreduce_base.cc deleted file mode 100644 index b99eb3763a3e..000000000000 --- a/rabit/src/allreduce_base.cc +++ /dev/null @@ -1,997 +0,0 @@ -/** - * Copyright 2014-2023, XGBoost Contributors - * \file allreduce_base.cc - * \brief Basic implementation of AllReduce - * - * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou - */ -#if !defined(NOMINMAX) && defined(_WIN32) -#define NOMINMAX -#endif // !defined(NOMINMAX) - -#include "allreduce_base.h" - -#include "rabit/base.h" -#include "rabit/internal/rabit-inl.h" -#include "xgboost/collective/result.h" - -#ifndef _WIN32 -#include -#endif // _WIN32 - -#include -#include - -namespace rabit::engine { -// constructor -AllreduceBase::AllreduceBase() { - tracker_uri = "NULL"; - tracker_port = 9000; - host_uri = ""; - rank = 0; - world_size = -1; - connect_retry = 5; - hadoop_mode = false; - version_number = 0; - // 32 K items - reduce_ring_mincount = 32 << 10; - // 1M reducer size each time - tree_reduce_minsize = 1 << 20; - // tracker URL - task_id = "NULL"; - err_link = nullptr; - dmlc_role = "worker"; - this->SetParam("rabit_reduce_buffer", "256MB"); - // setup possible environment variable of interest - // include dmlc support direct variables - env_vars.emplace_back("DMLC_TASK_ID"); - env_vars.emplace_back("DMLC_ROLE"); - env_vars.emplace_back("DMLC_NUM_ATTEMPT"); - env_vars.emplace_back("DMLC_TRACKER_URI"); - env_vars.emplace_back("DMLC_TRACKER_PORT"); - env_vars.emplace_back("DMLC_WORKER_CONNECT_RETRY"); -} - -// initialization function -bool AllreduceBase::Init(int argc, char* argv[]) { - // setup from environment variables - // handler to get variables from env - for (auto & env_var : env_vars) { - const char *value = getenv(env_var.c_str()); - if (value != nullptr) { - this->SetParam(env_var.c_str(), value); - } - } - // pass in arguments override env variable. - for (int i = 0; i < argc; ++i) { - char name[256], val[256]; - if (sscanf(argv[i], "%[^=]=%s", name, val) == 2) { - this->SetParam(name, val); - } - } - - { - // handling for hadoop - const char *task_id = getenv("mapred_tip_id"); - if (task_id == nullptr) { - task_id = getenv("mapreduce_task_id"); - } - if (hadoop_mode) { - utils::Check(task_id != nullptr, - "hadoop_mode is set but cannot find mapred_task_id"); - } - if (task_id != nullptr) { - this->SetParam("rabit_task_id", task_id); - this->SetParam("rabit_hadoop_mode", "1"); - } - const char *attempt_id = getenv("mapred_task_id"); - if (attempt_id != nullptr) { - const char *att = strrchr(attempt_id, '_'); - int num_trial; - if (att != nullptr && sscanf(att + 1, "%d", &num_trial) == 1) { - this->SetParam("rabit_num_trial", att + 1); - } - } - // handling for hadoop - const char *num_task = getenv("mapred_map_tasks"); - if (num_task == nullptr) { - num_task = getenv("mapreduce_job_maps"); - } - if (hadoop_mode) { - utils::Check(num_task != nullptr, - "hadoop_mode is set but cannot find mapred_map_tasks"); - } - if (num_task != nullptr) { - this->SetParam("rabit_world_size", num_task); - } - } - if (dmlc_role != "worker") { - LOG(FATAL) << "Rabit Module currently only works with dmlc worker"; - } - - // clear the setting before start reconnection - this->rank = -1; - //--------------------- - // start socket - xgboost::system::SocketStartup(); - utils::Assert(all_links.size() == 0, "can only call Init once"); - auto rc = xgboost::collective::GetHostName(&this->host_uri); - if (!rc.OK()) { - LOG(FATAL) << rc.Report(); - } - // get information from tracker - rc = this->ReConnectLinks(); - if (rc.OK()) { - return true; - } - LOG(FATAL) << rc.Report(); - return false; -} - -bool AllreduceBase::Shutdown() { - try { - for (auto &all_link : all_links) { - if (!all_link.sock.IsClosed()) { - all_link.sock.Close(); - } - } - all_links.clear(); - tree_links.plinks.clear(); - - if (tracker_uri == "NULL") return true; - // notify tracker rank i have shutdown - xgboost::collective::TCPSocket tracker; - auto rc = this->ConnectTracker(&tracker); - if (!rc.OK()) { - LOG(FATAL) << rc.Report(); - } - tracker.Send(xgboost::StringView{"shutdown"}); - tracker.Close(); - xgboost::system::SocketFinalize(); - return true; - } catch (std::exception const &e) { - LOG(WARNING) << "Failed to shutdown due to" << e.what(); - return false; - } -} - -void AllreduceBase::TrackerPrint(const std::string &msg) { - if (tracker_uri == "NULL") { - utils::Printf("%s", msg.c_str()); return; - } - xgboost::collective::TCPSocket tracker; - auto rc = this->ConnectTracker(&tracker); - if (!rc.OK()) { - LOG(FATAL) << rc.Report(); - } - - tracker.Send(xgboost::StringView{"print"}); - tracker.Send(xgboost::StringView{msg}); - tracker.Close(); -} - -// util to parse data with unit suffix -inline size_t ParseUnit(const char *name, const char *val) { - char unit; - unsigned long amt; // NOLINT(*) - int n = sscanf(val, "%lu%c", &amt, &unit); - size_t amount = amt; - if (n == 2) { - switch (unit) { - case 'B': return amount; - case 'K': return amount << 10UL; - case 'M': return amount << 20UL; - case 'G': return amount << 30UL; - default: utils::Error("invalid format for %s", name); return 0; - } - } else if (n == 1) { - return amount; - } else { - utils::Error("invalid format for %s," \ - "shhould be {integer}{unit}, unit can be {B, KB, MB, GB}", name); - return 0; - } -} -/*! - * \brief set parameters to the engine - * \param name parameter name - * \param val parameter value - */ -void AllreduceBase::SetParam(const char *name, const char *val) { - if (!strcmp(name, "rabit_tracker_uri")) tracker_uri = val; - if (!strcmp(name, "rabit_tracker_port")) tracker_port = atoi(val); - if (!strcmp(name, "rabit_task_id")) task_id = val; - if (!strcmp(name, "DMLC_TRACKER_URI")) tracker_uri = val; - if (!strcmp(name, "DMLC_TRACKER_PORT")) tracker_port = atoi(val); - if (!strcmp(name, "DMLC_TASK_ID")) task_id = val; - if (!strcmp(name, "DMLC_ROLE")) dmlc_role = val; - if (!strcmp(name, "rabit_world_size")) world_size = atoi(val); - if (!strcmp(name, "rabit_hadoop_mode")) hadoop_mode = utils::StringToBool(val); - if (!strcmp(name, "rabit_tree_reduce_minsize")) tree_reduce_minsize = atoi(val); - if (!strcmp(name, "rabit_reduce_ring_mincount")) { - reduce_ring_mincount = atoi(val); - utils::Assert(reduce_ring_mincount > 0, "rabit_reduce_ring_mincount should be greater than 0"); - } - if (!strcmp(name, "rabit_reduce_buffer")) { - reduce_buffer_size = (ParseUnit(name, val) + 7) >> 3; - } - if (!strcmp(name, "DMLC_WORKER_CONNECT_RETRY")) { - connect_retry = atoi(val); - } - if (!strcmp(name, "rabit_timeout")) { - rabit_timeout = utils::StringToBool(val); - } - if (!strcmp(name, "rabit_timeout_sec")) { - timeout_sec = std::chrono::seconds(atoi(val)); - utils::Assert(timeout_sec.count() >= 0, "rabit_timeout_sec should be non negative second"); - } - if (!strcmp(name, "rabit_enable_tcp_no_delay")) { - if (!strcmp(val, "true")) { - rabit_enable_tcp_no_delay = true; - } else { - rabit_enable_tcp_no_delay = false; - } - } -} - -/*! - * \brief initialize connection to the tracker - * \return a socket that initializes the connection - */ -[[nodiscard]] xgboost::collective::Result AllreduceBase::ConnectTracker( - xgboost::collective::TCPSocket *out) const { - int magic = kMagic; - // get information from tracker - xgboost::collective::TCPSocket &tracker = *out; - - auto rc = - Connect(xgboost::StringView{tracker_uri}, tracker_port, connect_retry, timeout_sec, &tracker); - if (!rc.OK()) { - return xgboost::collective::Fail("Failed to connect to the tracker.", std::move(rc)); - } - - using utils::Assert; - if (tracker.SendAll(&magic, sizeof(magic)) != sizeof(magic)) { - return xgboost::collective::Fail("Failed to send the verification number."); - } - if (tracker.RecvAll(&magic, sizeof(magic)) != sizeof(magic)) { - return xgboost::collective::Fail("Failed to recieve the verification number."); - } - if (magic != kMagic) { - return xgboost::collective::Fail("Invalid verification number."); - } - if (tracker.SendAll(&rank, sizeof(rank)) != sizeof(rank)) { - return xgboost::collective::Fail("Failed to send the local rank back to the tracker."); - } - if (tracker.SendAll(&world_size, sizeof(world_size)) != sizeof(world_size)) { - return xgboost::collective::Fail("Failed to send the world size back to the tracker."); - } - if (tracker.Send(xgboost::StringView{task_id}) != task_id.size()) { - return xgboost::collective::Fail("Failed to send the task ID back to the tracker."); - } - - return xgboost::collective::Success(); -} -/*! - * \brief connect to the tracker to fix the missing links - * this function is also used when the engine start up - */ -[[nodiscard]] xgboost::collective::Result AllreduceBase::ReConnectLinks(const char *cmd) { - // single node mode - if (tracker_uri == "NULL") { - rank = 0; - world_size = 1; - return xgboost::collective::Success(); - } - - xgboost::collective::TCPSocket tracker; - auto rc = this->ConnectTracker(&tracker); - if (!rc.OK()) { - return xgboost::collective::Fail("Failed to connect to the tracker.", std::move(rc)); - } - - LOG(INFO) << "task " << task_id << " connected to the tracker"; - tracker.Send(xgboost::StringView{cmd}); - - try { - // the rank of previous link, next link in ring - int prev_rank, next_rank; - // the rank of neighbors - std::map tree_neighbors; - using utils::Assert; - // get new ranks - int newrank, num_neighbors; - Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank), - "ReConnectLink failure 4"); - Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) == \ - sizeof(parent_rank), "ReConnectLink failure 4"); - Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size), - "ReConnectLink failure 4"); - Assert(rank == -1 || newrank == rank, - "must keep rank to same if the node already have one"); - rank = newrank; - - if (rank == -1) { - LOG(FATAL) << "tracker got overwhelmed and not able to assign correct rank"; - } - - LOG(CONSOLE) << "task " << task_id << " got new rank " << rank; - - Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \ - sizeof(num_neighbors), "ReConnectLink failure 4"); - for (int i = 0; i < num_neighbors; ++i) { - int nrank; - Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank), - "ReConnectLink failure 4"); - tree_neighbors[nrank] = 1; - } - Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank), - "ReConnectLink failure 4"); - Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank), - "ReConnectLink failure 4"); - - auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())}; - // create listening socket - int port = sock_listen.BindHost(); - utils::Check(port != -1, "ReConnectLink fail to bind the ports specified"); - sock_listen.Listen(); - - // get number of to connect and number of to accept nodes from tracker - int num_conn, num_accept, num_error = 1; - do { - for (auto & all_link : all_links) { - all_link.sock.Close(); - } - // tracker construct goodset - Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn), - "ReConnectLink failure 7"); - Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept), - "ReConnectLink failure 8"); - num_error = 0; - for (int i = 0; i < num_conn; ++i) { - LinkRecord r; - int hport, hrank; - std::string hname; - tracker.Recv(&hname); - Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9"); - Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10"); - // connect to peer - if (!xgboost::collective::Connect(xgboost::StringView{hname}, hport, connect_retry, - timeout_sec, &r.sock) - .OK()) { - num_error += 1; - r.sock.Close(); - continue; - } - Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), - "ReConnectLink failure 12"); - Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), - "ReConnectLink failure 13"); - utils::Check(hrank == r.rank, - "ReConnectLink failure, link rank inconsistent"); - bool match = false; - for (auto & all_link : all_links) { - if (all_link.rank == hrank) { - Assert(all_link.sock.IsClosed(), "Override a link that is active"); - all_link.sock = std::move(r.sock); - match = true; - break; - } - } - if (!match) all_links.emplace_back(std::move(r)); - } - Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error), - "ReConnectLink failure 14"); - } while (num_error != 0); - // send back socket listening port to tracker - Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14"); - // close connection to tracker - tracker.Close(); - - // listen to incoming links - for (int i = 0; i < num_accept; ++i) { - LinkRecord r; - r.sock = sock_listen.Accept(); - Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), - "ReConnectLink failure 15"); - Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), - "ReConnectLink failure 15"); - bool match = false; - for (auto & all_link : all_links) { - if (all_link.rank == r.rank) { - utils::Assert(all_link.sock.IsClosed(), - "Override a link that is active"); - all_link.sock = std::move(r.sock); - match = true; - break; - } - } - if (!match) all_links.emplace_back(std::move(r)); - } - sock_listen.Close(); - - this->parent_index = -1; - // setup tree links and ring structure - tree_links.plinks.clear(); - for (auto &all_link : all_links) { - utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket"); - // set the socket to non-blocking mode, enable TCP keepalive - CHECK(all_link.sock.NonBlocking(true).OK()); - CHECK(all_link.sock.SetKeepAlive().OK()); - if (rabit_enable_tcp_no_delay) { - CHECK(all_link.sock.SetNoDelay().OK()); - } - if (tree_neighbors.count(all_link.rank) != 0) { - if (all_link.rank == parent_rank) { - parent_index = static_cast(tree_links.plinks.size()); - } - tree_links.plinks.push_back(&all_link); - } - if (all_link.rank == prev_rank) ring_prev = &all_link; - if (all_link.rank == next_rank) ring_next = &all_link; - } - Assert(parent_rank == -1 || parent_index != -1, - "cannot find parent in the link"); - Assert(prev_rank == -1 || ring_prev != nullptr, - "cannot find prev ring in the link"); - Assert(next_rank == -1 || ring_next != nullptr, - "cannot find next ring in the link"); - return xgboost::collective::Success(); - } catch (const std::exception& e) { - std::stringstream ss; - ss << "Failed in ReconnectLink " << e.what(); - return xgboost::collective::Fail(ss.str()); - } -} -/*! - * \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure - * - * NOTE on Allreduce: - * The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce. - * It only means the current node get the correct result of Allreduce. - * However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast - * - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param type_nbytes the unit number of bytes the type have - * \param count number of elements to be reduced - * \param reducer reduce function - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType - */ -AllreduceBase::ReturnType -AllreduceBase::TryAllreduce(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer) { - if (count > reduce_ring_mincount) { - return this->TryAllreduceRing(sendrecvbuf_, type_nbytes, count, reducer); - } else { - return this->TryAllreduceTree(sendrecvbuf_, type_nbytes, count, reducer); - } -} -/*! - * \brief perform in-place allreduce, on sendrecvbuf, - * this function implements tree-shape reduction - * - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param type_nbytes the unit number of bytes the type have - * \param count number of elements to be reduced - * \param reducer reduce function - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType - */ -AllreduceBase::ReturnType -AllreduceBase::TryAllreduceTree(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer) { - RefLinkVector &links = tree_links; - if (links.Size() == 0 || count == 0) return kSuccess; - // total size of message - const size_t total_size = type_nbytes * count; - // number of links - const int nlink = static_cast(links.Size()); - // send recv buffer - char *sendrecvbuf = reinterpret_cast(sendrecvbuf_); - // size of space that we already performs reduce in up pass - size_t size_up_reduce = 0; - // size of space that we have already passed to parent - size_t size_up_out = 0; - // size of message we received, and send in the down pass - size_t size_down_in = 0; - // minimal size of each reducer - const size_t eachreduce = (tree_reduce_minsize / type_nbytes * type_nbytes); - - // initialize the link ring-buffer and pointer - for (int i = 0; i < nlink; ++i) { - if (i != parent_index) { - links[i].InitBuffer(type_nbytes, count, reduce_buffer_size); - } - links[i].ResetSize(); - } - // if no children, no need to reduce - if (nlink == static_cast(parent_index != -1)) { - size_up_reduce = total_size; - } - // while we have not passed the messages out - while (true) { - // select helper - bool finished = true; - utils::PollHelper watcher; - for (int i = 0; i < nlink; ++i) { - if (i == parent_index) { - if (size_down_in != total_size) { - watcher.WatchRead(links[i].sock); - // only watch for exception in live channels - watcher.WatchException(links[i].sock); - finished = false; - } - if (size_up_out != total_size && size_up_out < size_up_reduce) { - watcher.WatchWrite(links[i].sock); - } - } else { - if (links[i].size_read != total_size) { - watcher.WatchRead(links[i].sock); - } - // size_write <= size_read - if (links[i].size_write != total_size) { - if (links[i].size_write < size_down_in) { - watcher.WatchWrite(links[i].sock); - } - // only watch for exception in live channels - watcher.WatchException(links[i].sock); - finished = false; - } - } - } - // finish running allreduce - if (finished) { - break; - } - // select must return - auto poll_res = watcher.Poll(timeout_sec, false); // fail on macos - if (!poll_res.OK()) { - LOG(FATAL) << poll_res.Report(); - } - - // read data from childs - for (int i = 0; i < nlink; ++i) { - if (i != parent_index && watcher.CheckRead(links[i].sock)) { - // make sure to receive minimal reducer size - // since each child reduce and sends the minimal reducer size - while (links[i].size_read < total_size - && links[i].size_read - size_up_reduce < eachreduce) { - ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size); - if (ret != kSuccess) { - return ReportError(&links[i], ret); - } - } - } - } - // this node have children, perform reduce - if (nlink > static_cast(parent_index != -1)) { - size_t buffer_size = 0; - // do upstream reduce - size_t max_reduce = total_size; - for (int i = 0; i < nlink; ++i) { - if (i != parent_index) { - max_reduce = std::min(max_reduce, links[i].size_read); - utils::Assert(buffer_size == 0 || buffer_size == links[i].buffer_size, - "buffer size inconsistent"); - buffer_size = links[i].buffer_size; - } - } - utils::Assert(buffer_size != 0, "must assign buffer_size"); - // round to type_n4bytes - max_reduce = (max_reduce / type_nbytes * type_nbytes); - - // if max reduce is less than total size, we reduce multiple times of - // each reduce size - if (max_reduce < total_size) { - max_reduce = max_reduce - max_reduce % eachreduce; - } - - // perform reduce, can be at most two rounds - while (size_up_reduce < max_reduce) { - // start position - size_t start = size_up_reduce % buffer_size; - // perform read till end of buffer - size_t nread = std::min(buffer_size - start, - max_reduce - size_up_reduce); - utils::Assert(nread % type_nbytes == 0, "Allreduce: size check"); - for (int i = 0; i < nlink; ++i) { - if (i != parent_index) { - reducer(links[i].buffer_head + start, - sendrecvbuf + size_up_reduce, - static_cast(nread / type_nbytes), - MPI::Datatype(type_nbytes)); - } - } - size_up_reduce += nread; - } - } - if (parent_index != -1) { - // pass message up to parent, can pass data that are already been reduced - if (size_up_out < size_up_reduce) { - ssize_t len = links[parent_index].sock. - Send(sendrecvbuf + size_up_out, size_up_reduce - size_up_out); - if (len != -1) { - size_up_out += static_cast(len); - } else { - ReturnType ret = Errno2Return(); - if (ret != kSuccess) { - return ReportError(&links[parent_index], ret); - } - } - } - // read data from parent - if (watcher.CheckRead(links[parent_index].sock) && - total_size > size_down_in) { - size_t left_size = total_size-size_down_in; - size_t reduce_size_min = std::min(left_size, eachreduce); - size_t recved = 0; - while (recved < reduce_size_min) { - ssize_t len = links[parent_index].sock. - Recv(sendrecvbuf + size_down_in, total_size - size_down_in); - - if (len == 0) { - links[parent_index].sock.Close(); - return ReportError(&links[parent_index], kRecvZeroLen); - } - if (len != -1) { - size_down_in += static_cast(len); - utils::Assert(size_down_in <= size_up_out, - "Allreduce: boundary error"); - recved+=len; - - // if it receives more data than each reduce, it means the next block is sent. - // we double the reduce_size_min or add to left_size - while (recved > reduce_size_min) { - reduce_size_min += std::min(left_size-reduce_size_min, eachreduce); - } - } else { - ReturnType ret = Errno2Return(); - if (ret != kSuccess) { - return ReportError(&links[parent_index], ret); - } - } - } - } - } else { - // this is root, can use reduce as most recent point - size_down_in = size_up_out = size_up_reduce; - } - // can pass message down to children - for (int i = 0; i < nlink; ++i) { - if (i != parent_index && links[i].size_write < size_down_in) { - ReturnType ret = links[i].WriteFromArray(sendrecvbuf, size_down_in); - if (ret != kSuccess) { - return ReportError(&links[i], ret); - } - } - } - } - return kSuccess; -} -/*! - * \brief broadcast data from root to all nodes, this function can fail,and will return the cause of failure - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param total_size the size of the data to be broadcasted - * \param root the root worker id to broadcast the data - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType - */ -AllreduceBase::ReturnType -AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { - RefLinkVector &links = tree_links; - if (links.Size() == 0 || total_size == 0) return kSuccess; - utils::Check(root < world_size, - "Broadcast: root should be smaller than world size"); - // number of links - const int nlink = static_cast(links.Size()); - // size of space already read from data - size_t size_in = 0; - // input link, -2 means unknown yet, -1 means this is root - int in_link = -2; - - // initialize the link statistics - for (int i = 0; i < nlink; ++i) { - links[i].ResetSize(); - } - // root have all the data - if (this->rank == root) { - size_in = total_size; - in_link = -1; - } - // while we have not passed the messages out - while (true) { - bool finished = true; - // select helper - utils::PollHelper watcher; - for (int i = 0; i < nlink; ++i) { - if (in_link == -2) { - watcher.WatchRead(links[i].sock); finished = false; - } - if (i == in_link && links[i].size_read != total_size) { - watcher.WatchRead(links[i].sock); finished = false; - } - if (in_link != -2 && i != in_link && links[i].size_write != total_size) { - if (links[i].size_write < size_in) { - watcher.WatchWrite(links[i].sock); - } - finished = false; - } - } - // finish running - if (finished) break; - // select - auto poll_res = watcher.Poll(timeout_sec, false); // fail on macos - if (!poll_res.OK()) { - LOG(FATAL) << poll_res.Report(); - } - if (in_link == -2) { - // probe in-link - for (int i = 0; i < nlink; ++i) { - if (watcher.CheckRead(links[i].sock)) { - ReturnType ret = links[i].ReadToArray(sendrecvbuf_, total_size); - if (ret != kSuccess) { - return ReportError(&links[i], ret); - } - size_in = links[i].size_read; - if (size_in != 0) { - in_link = i; break; - } - } - } - } else { - // read from in link - if (in_link >= 0 && watcher.CheckRead(links[in_link].sock)) { - ReturnType ret = links[in_link].ReadToArray(sendrecvbuf_, total_size); - if (ret != kSuccess) { - return ReportError(&links[in_link], ret); - } - size_in = links[in_link].size_read; - } - } - // send data to all out-link - for (int i = 0; i < nlink; ++i) { - if (i != in_link && links[i].size_write < size_in) { - ReturnType ret = links[i].WriteFromArray(sendrecvbuf_, size_in); - if (ret != kSuccess) { - return ReportError(&links[i], ret); - } - } - } - } - return kSuccess; -} -/*! - * \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf, - * the data provided by current node k is [slice_begin, slice_end), - * the next node's segment must start with slice_end - * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments - * use a ring based algorithm - * - * \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually - * \param total_size total size of data to be gathered - * \param slice_begin beginning of the current slice - * \param slice_end end of the current slice - * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size - */ -AllreduceBase::ReturnType -AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size, - size_t slice_begin, - size_t slice_end, - size_t size_prev_slice) { - // read from next link and send to prev one - LinkRecord &prev = *ring_prev, &next = *ring_next; - // need to reply on special rank structure - utils::Assert(next.rank == (rank + 1) % world_size && - rank == (prev.rank + 1) % world_size, - "need to assume rank structure"); - // send recv buffer - char *sendrecvbuf = reinterpret_cast(sendrecvbuf_); - const size_t stop_read = total_size + slice_begin; - const size_t stop_write = total_size + slice_begin - size_prev_slice; - size_t write_ptr = slice_begin; - size_t read_ptr = slice_end; - - while (true) { - // select helper - bool finished = true; - utils::PollHelper watcher; - if (read_ptr != stop_read) { - watcher.WatchRead(next.sock); - finished = false; - } - if (write_ptr != stop_write) { - if (write_ptr < read_ptr) { - watcher.WatchWrite(prev.sock); - } - finished = false; - } - if (finished) { - break; - } - - auto poll_res = watcher.Poll(timeout_sec, false); // fail on macos - if (!poll_res.OK()) { - LOG(FATAL) << poll_res.Report(); - } - if (read_ptr != stop_read && watcher.CheckRead(next.sock)) { - size_t size = stop_read - read_ptr; - size_t start = read_ptr % total_size; - if (start + size > total_size) { - size = total_size - start; - } - ssize_t len = next.sock.Recv(sendrecvbuf + start, size); - if (len != -1) { - read_ptr += static_cast(len); - } else { - ReturnType ret = Errno2Return(); - if (ret != kSuccess) { - auto err = ReportError(&next, ret); - return err; - } - } - } - if (write_ptr < read_ptr && write_ptr != stop_write) { - size_t size = std::min(read_ptr, stop_write) - write_ptr; - size_t start = write_ptr % total_size; - if (start + size > total_size) { - size = total_size - start; - } - ssize_t len = prev.sock.Send(sendrecvbuf + start, size); - if (len != -1) { - write_ptr += static_cast(len); - } else { - ReturnType ret = Errno2Return(); - if (ret != kSuccess) { - auto err = ReportError(&prev, ret); - return err; - } - } - } - } - return kSuccess; -} -/*! - * \brief perform in-place allreduce, on sendrecvbuf, this function can fail, - * and will return the cause of failure - * - * Ring-based algorithm - * - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param type_nbytes the unit number of bytes the type have - * \param count number of elements to be reduced - * \param reducer reduce function - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType, TryAllreduce - */ -AllreduceBase::ReturnType -AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer) { - // read from next link and send to prev one - LinkRecord &prev = *ring_prev, &next = *ring_next; - // need to reply on special rank structure - utils::Assert(next.rank == (rank + 1) % world_size && - rank == (prev.rank + 1) % world_size, - "need to assume rank structure"); - // total size of message - const size_t total_size = type_nbytes * count; - size_t n = static_cast(world_size); - size_t step = (count + n - 1) / n; - size_t r = static_cast(next.rank); - size_t write_ptr = std::min(r * step, count) * type_nbytes; - size_t read_ptr = std::min((r + 1) * step, count) * type_nbytes; - size_t reduce_ptr = read_ptr; - // send recv buffer - char *sendrecvbuf = reinterpret_cast(sendrecvbuf_); - // position to stop reading - const size_t stop_read = total_size + write_ptr; - // position to stop writing - size_t stop_write = total_size + std::min(rank * step, count) * type_nbytes; - if (stop_write > stop_read) { - stop_write -= total_size; - utils::Assert(write_ptr <= stop_write, "write ptr boundary check"); - } - // use ring buffer in next position - next.InitBuffer(type_nbytes, step, reduce_buffer_size); - // set size_read to read pointer for ring buffer to work properly - next.size_read = read_ptr; - - while (true) { - // select helper - bool finished = true; - utils::PollHelper watcher; - if (read_ptr != stop_read) { - watcher.WatchRead(next.sock); - finished = false; - } - if (write_ptr != stop_write) { - if (write_ptr < reduce_ptr) { - watcher.WatchWrite(prev.sock); - } - finished = false; - } - if (finished) { - break; - } - auto poll_res = watcher.Poll(timeout_sec, false); // fail on macos - if (!poll_res.OK()) { - LOG(FATAL) << poll_res.Report(); - } - if (read_ptr != stop_read && watcher.CheckRead(next.sock)) { - ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read); - if (ret != kSuccess) { - return ReportError(&next, ret); - } - // sync the rate - read_ptr = next.size_read; - utils::Assert(read_ptr <= stop_read, "[%d] read_ptr boundary check", rank); - const size_t buffer_size = next.buffer_size; - size_t max_reduce = (read_ptr / type_nbytes) * type_nbytes; - while (reduce_ptr < max_reduce) { - size_t bstart = reduce_ptr % buffer_size; - size_t nread = std::min(buffer_size - bstart, - max_reduce - reduce_ptr); - size_t rstart = reduce_ptr % total_size; - nread = std::min(nread, total_size - rstart); - reducer(next.buffer_head + bstart, - sendrecvbuf + rstart, - static_cast(nread / type_nbytes), - MPI::Datatype(type_nbytes)); - reduce_ptr += nread; - } - } - if (write_ptr < reduce_ptr && write_ptr != stop_write) { - size_t size = std::min(reduce_ptr, stop_write) - write_ptr; - size_t start = write_ptr % total_size; - if (start + size > total_size) { - size = total_size - start; - } - ssize_t len = prev.sock.Send(sendrecvbuf + start, size); - if (len != -1) { - write_ptr += static_cast(len); - } else { - ReturnType ret = Errno2Return(); - if (ret != kSuccess) return ReportError(&prev, ret); - } - } - } - return kSuccess; -} -/*! - * \brief perform in-place allreduce, on sendrecvbuf - * use a ring based algorithm - * - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param type_nbytes the unit number of bytes the type have - * \param count number of elements to be reduced - * \param reducer reduce function - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType - */ -AllreduceBase::ReturnType -AllreduceBase::TryAllreduceRing(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer) { - ReturnType ret = TryReduceScatterRing(sendrecvbuf_, type_nbytes, count, reducer); - if (ret != kSuccess) return ret; - size_t n = static_cast(world_size); - size_t step = (count + n - 1) / n; - size_t begin = std::min(rank * step, count) * type_nbytes; - size_t end = std::min((rank + 1) * step, count) * type_nbytes; - // previous rank - int prank = ring_prev->rank; - // get rank of previous - return TryAllgatherRing - (sendrecvbuf_, type_nbytes * count, - begin, end, - (std::min((prank + 1) * step, count) - - std::min(prank * step, count)) * type_nbytes); -} -} // namespace rabit::engine diff --git a/rabit/src/allreduce_base.h b/rabit/src/allreduce_base.h deleted file mode 100644 index 7724bf3d58c3..000000000000 --- a/rabit/src/allreduce_base.h +++ /dev/null @@ -1,501 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file allreduce_base.h - * \brief Basic implementation of AllReduce - * using TCP non-block socket and tree-shape reduction. - * - * This implementation provides basic utility of AllReduce and Broadcast - * without considering node failure - * - * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou - */ -#ifndef RABIT_ALLREDUCE_BASE_H_ -#define RABIT_ALLREDUCE_BASE_H_ - -#include -#include -#include -#include -#include - -#include "rabit/internal/engine.h" -#include "rabit/internal/socket.h" -#include "rabit/internal/utils.h" -#include "xgboost/collective/result.h" - -#ifdef RABIT_CXXTESTDEFS_H -#define private public -#define protected public -#endif // RABIT_CXXTESTDEFS_H - - -namespace MPI { // NOLINT -// MPI data type to be compatible with existing MPI interface -class Datatype { - public: - size_t type_size; - explicit Datatype(size_t type_size) : type_size(type_size) {} -}; -} -namespace rabit { -namespace engine { - -/*! \brief implementation of basic Allreduce engine */ -class AllreduceBase : public IEngine { - public: - // magic number to verify server - static const int kMagic = 0xff99; - // constant one byte out of band message to indicate error happening - AllreduceBase(); - virtual ~AllreduceBase() = default; - // initialize the manager - virtual bool Init(int argc, char* argv[]); - // shutdown the engine - virtual bool Shutdown(); - /*! - * \brief set parameters to the engine - * \param name parameter name - * \param val parameter value - */ - virtual void SetParam(const char *name, const char *val); - /*! - * \brief print the msg in the tracker, - * this function can be used to communicate the information of the progress to - * the user who monitors the tracker - * \param msg message to be printed in the tracker - */ - void TrackerPrint(const std::string &msg) override; - - /*! \brief get rank of previous node in ring topology*/ - int GetRingPrevRank() const override { - return ring_prev->rank; - } - /*! \brief get rank */ - int GetRank() const override { - return rank; - } - /*! \brief get rank */ - int GetWorldSize() const override { - if (world_size == -1) return 1; - return world_size; - } - /*! \brief whether is distributed or not */ - bool IsDistributed() const override { - return tracker_uri != "NULL"; - } - /*! \brief get rank */ - std::string GetHost() const override { - return host_uri; - } - - /*! - * \brief internal Allgather function, each node has a segment of data in the ring of sendrecvbuf, - * the data provided by current node k is [slice_begin, slice_end), - * the next node's segment must start with slice_end - * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments - * use a ring based algorithm - * - * \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually - * \param total_size total size of data to be gathered - * \param slice_begin beginning of the current slice - * \param slice_end end of the current slice - * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size - */ - void Allgather(void *sendrecvbuf_, size_t total_size, size_t slice_begin, - size_t slice_end, size_t size_prev_slice) override { - if (world_size == 1 || world_size == -1) { - return; - } - utils::Assert(TryAllgatherRing(sendrecvbuf_, total_size, slice_begin, - slice_end, size_prev_slice) == kSuccess, - "AllgatherRing failed"); - } - /*! - * \brief perform in-place allreduce, on sendrecvbuf - * this function is NOT thread-safe - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param type_nbytes the unit number of bytes the type have - * \param count number of elements to be reduced - * \param reducer reduce function - * \param prepare_func Lazy preprocessing function, lazy prepare_fun(prepare_arg) - * will be called by the function before performing Allreduce, to initialize the data in sendrecvbuf_. - * If the result of Allreduce can be recovered directly, then prepare_func will NOT be called - * \param prepare_arg argument used to passed into the lazy preprocessing function - */ - void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, - ReduceFunction reducer, PreprocFunction prepare_fun = nullptr, - void *prepare_arg = nullptr) override { - if (prepare_fun != nullptr) prepare_fun(prepare_arg); - if (world_size == 1 || world_size == -1) return; - utils::Assert(TryAllreduce(sendrecvbuf_, type_nbytes, count, reducer) == - kSuccess, - "Allreduce failed"); - } - /*! - * \brief broadcast data from root to all nodes - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param size the size of the data to be broadcasted - * \param root the root worker id to broadcast the data - * \param _file caller file name used to generate unique cache key - * \param _line caller line number used to generate unique cache key - * \param _caller caller function name used to generate unique cache key - */ - void Broadcast(void *sendrecvbuf_, size_t total_size, int root) override { - if (world_size == 1 || world_size == -1) return; - utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess, - "Broadcast failed"); - } - /*! - * \brief deprecated - * \sa CheckPoint, VersionNumber - */ - int LoadCheckPoint() override { return 0; } - - // deprecated, increase internal version number - void CheckPoint() override { version_number += 1; } - /*! - * \return version number of current stored model, - * which means how many calls to CheckPoint we made so far - * \sa LoadCheckPoint, CheckPoint - */ - int VersionNumber() const override { - return version_number; - } - /*! - * \brief report current status to the job tracker - * depending on the job tracker we are in - */ - inline void ReportStatus() const { - if (hadoop_mode != 0) { - LOG(CONSOLE) << "reporter:status:Rabit Phase[" << version_number << "] Operation " << seq_counter << "\n"; - } - } - - protected: - /*! \brief enumeration of possible returning results from Try functions */ - enum ReturnTypeEnum { - /*! \brief execution is successful */ - kSuccess, - /*! \brief a link was reset by peer */ - kConnReset, - /*! \brief received a zero length message */ - kRecvZeroLen, - /*! \brief a neighbor node go down, the connection is dropped */ - kSockError, - /*! - * \brief another node which is not my neighbor go down, - * get Out-of-Band exception notification from my neighbor - */ - kGetExcept - }; - /*! \brief struct return type to avoid implicit conversion to int/bool */ - struct ReturnType { - /*! \brief internal return type */ - ReturnTypeEnum value; - // constructor - ReturnType() = default; - ReturnType(ReturnTypeEnum value) : value(value) {} // NOLINT(*) - inline bool operator==(const ReturnTypeEnum &v) const { - return value == v; - } - inline bool operator!=(const ReturnTypeEnum &v) const { - return value != v; - } - }; - /*! \brief translate errno to return type */ - static ReturnType Errno2Return() { - int errsv = xgboost::system::LastError(); - if (errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == 0) return kSuccess; -#ifdef _WIN32 - if (errsv == WSAEWOULDBLOCK) return kSuccess; - if (errsv == WSAECONNRESET) return kConnReset; -#endif // _WIN32 - if (errsv == ECONNRESET) return kConnReset; - return kSockError; - } - // link record to a neighbor - struct LinkRecord { - public: - // socket to get data from/to link - xgboost::collective::TCPSocket sock; - // rank of the node in this link - int rank; - // size of data readed from link - size_t size_read; - // size of data sent to the link - size_t size_write; - // pointer to buffer head - char *buffer_head {nullptr}; - // buffer size, in bytes - size_t buffer_size {0}; - // constructor - LinkRecord() = default; - // initialize buffer - void InitBuffer(size_t type_nbytes, size_t count, - size_t reduce_buffer_size) { - size_t n = (type_nbytes * count + 7)/ 8; - auto to = Min(reduce_buffer_size, n); - buffer_.resize(to); - // make sure align to type_nbytes - buffer_size = - buffer_.size() * sizeof(uint64_t) / type_nbytes * type_nbytes; - utils::Assert(type_nbytes <= buffer_size, - "too large type_nbytes=%lu, buffer_size=%lu", - type_nbytes, buffer_size); - // set buffer head - buffer_head = reinterpret_cast(BeginPtr(buffer_)); - } - // reset the recv and sent size - inline void ResetSize() { - size_write = size_read = 0; - } - /*! - * \brief read data into ring-buffer, with care not to existing useful override data - * position after protect_start - * \param protect_start all data start from protect_start is still needed in buffer - * read shall not override this - * \param max_size_read maximum logical amount we can read, size_read cannot exceed this value - * \return the type of reading - */ - inline ReturnType ReadToRingBuffer(size_t protect_start, size_t max_size_read) { - utils::Assert(buffer_head != nullptr, "ReadToRingBuffer: buffer not allocated"); - utils::Assert(size_read <= max_size_read, "ReadToRingBuffer: max_size_read check"); - size_t ngap = size_read - protect_start; - utils::Assert(ngap <= buffer_size, "Allreduce: boundary check"); - size_t offset = size_read % buffer_size; - size_t nmax = max_size_read - size_read; - nmax = Min(nmax, buffer_size - ngap); - nmax = Min(nmax, buffer_size - offset); - if (nmax == 0) return kSuccess; - ssize_t len = sock.Recv(buffer_head + offset, nmax); - // length equals 0, remote disconnected - if (len == 0) { - sock.Close(); return kRecvZeroLen; - } - if (len == -1) return Errno2Return(); - size_read += static_cast(len); - return kSuccess; - } - /*! - * \brief read data into array, - * this function can not be used together with ReadToRingBuffer - * a link can either read into the ring buffer, or existing array - * \param max_size maximum size of array - * \return true if it is a successful read, false if there is some error happens, check errno - */ - inline ReturnType ReadToArray(void *recvbuf_, size_t max_size) { - if (max_size == size_read) return kSuccess; - char *p = static_cast(recvbuf_); - ssize_t len = sock.Recv(p + size_read, max_size - size_read); - // length equals 0, remote disconnected - if (len == 0) { - sock.Close(); return kRecvZeroLen; - } - if (len == -1) return Errno2Return(); - size_read += static_cast(len); - return kSuccess; - } - /*! - * \brief write data in array to sock - * \param sendbuf_ head of array - * \param max_size maximum size of array - * \return true if it is a successful write, false if there is some error happens, check errno - */ - inline ReturnType WriteFromArray(const void *sendbuf_, size_t max_size) { - const char *p = static_cast(sendbuf_); - ssize_t len = sock.Send(p + size_write, max_size - size_write); - if (len == -1) return Errno2Return(); - size_write += static_cast(len); - return kSuccess; - } - - private: - // recv buffer to get data from child - // aligned with 64 bits, will be able to perform 64 bits operations freely - std::vector buffer_; - }; - /*! - * \brief simple data structure that works like a vector - * but takes reference instead of space - */ - struct RefLinkVector { - std::vector plinks; - inline LinkRecord &operator[](size_t i) { - return *plinks[i]; - } - inline size_t Size() const { - return plinks.size(); - } - }; - /*! - * \brief initialize connection to the tracker - * \return a socket that initializes the connection - */ - [[nodiscard]] xgboost::collective::Result ConnectTracker(xgboost::collective::TCPSocket *out) const; - /*! - * \brief connect to the tracker to fix the missing links - * this function is also used when the engine start up - * \param cmd possible command to sent to tracker - */ - [[nodiscard]] xgboost::collective::Result ReConnectLinks(const char *cmd = "start"); - /*! - * \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure - * - * NOTE on Allreduce: - * The kSuccess TryAllreduce does NOT mean every node have successfully finishes TryAllreduce. - * It only means the current node get the correct result of Allreduce. - * However, it means every node finishes LAST call(instead of this one) of Allreduce/Bcast - * - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param type_nbytes the unit number of bytes the type have - * \param count number of elements to be reduced - * \param reducer reduce function - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType - */ - ReturnType TryAllreduce(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer); - /*! - * \brief broadcast data from root to all nodes, this function can fail, and will return the cause of failure - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param size the size of the data to be broadcasted - * \param root the root worker id to broadcast the data - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType - */ - ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root); - /*! - * \brief perform in-place allreduce, on sendrecvbuf, - * this function implements tree-shape reduction - * - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param type_nbytes the unit number of bytes the type have - * \param count number of elements to be reduced - * \param reducer reduce function - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType - */ - ReturnType TryAllreduceTree(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer); - /*! - * \brief internal Allgather function, each node have a segment of data in the ring of sendrecvbuf, - * the data provided by current node k is [slice_begin, slice_end), - * the next node's segment must start with slice_end - * after the call of Allgather, sendrecvbuf_ contains all the contents including all segments - * use a ring based algorithm - * - * \param sendrecvbuf_ buffer for both sending and receiving data, it is a ring conceptually - * \param total_size total size of data to be gathered - * \param slice_begin beginning of the current slice - * \param slice_end end of the current slice - * \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType - */ - ReturnType TryAllgatherRing(void *sendrecvbuf_, size_t total_size, - size_t slice_begin, size_t slice_end, - size_t size_prev_slice); - /*! - * \brief perform in-place allreduce, reduce on the sendrecvbuf, - * - * after the function, node k get k-th segment of the reduction result - * the k-th segment is defined by [k * step, min((k + 1) * step,count) ) - * where step = ceil(count / world_size) - * - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param type_nbytes the unit number of bytes the type have - * \param count number of elements to be reduced - * \param reducer reduce function - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType, TryAllreduce - */ - ReturnType TryReduceScatterRing(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer); - /*! - * \brief perform in-place allreduce, on sendrecvbuf - * use a ring based algorithm, reduce-scatter + allgather - * - * \param sendrecvbuf_ buffer for both sending and receiving data - * \param type_nbytes the unit number of bytes the type have - * \param count number of elements to be reduced - * \param reducer reduce function - * \return this function can return kSuccess, kSockError, kGetExcept, see ReturnType for details - * \sa ReturnType - */ - ReturnType TryAllreduceRing(void *sendrecvbuf_, - size_t type_nbytes, - size_t count, - ReduceFunction reducer); - /*! - * \brief function used to report error when a link goes wrong - * \param link the pointer to the link who causes the error - * \param err the error type - */ - inline ReturnType ReportError(LinkRecord *link, ReturnType err) { - err_link = link; return err; - } - //---- data structure related to model ---- - // call sequence counter, records how many calls we made so far - // from last call to CheckPoint, LoadCheckPoint - int seq_counter{0}; // NOLINT - // version number of model - int version_number {0}; // NOLINT - // whether the job is running in Hadoop - bool hadoop_mode; // NOLINT - //---- local data related to link ---- - // index of parent link, can be -1, meaning this is root of the tree - int parent_index; // NOLINT - // rank of parent node, can be -1 - int parent_rank; // NOLINT - // sockets of all links this connects to - std::vector all_links; // NOLINT - // used to record the link where things goes wrong - LinkRecord *err_link; // NOLINT - // all the links in the reduction tree connection - RefLinkVector tree_links; // NOLINT - // pointer to links in the ring - LinkRecord *ring_prev, *ring_next; // NOLINT - //----- meta information----- - // list of enviroment variables that are of possible interest - std::vector env_vars; // NOLINT - // unique identifier of the possible job this process is doing - // used to assign ranks, optional, default to NULL - std::string task_id; // NOLINT - // uri of current host, to be set by Init - std::string host_uri; // NOLINT - // uri of tracker - std::string tracker_uri; // NOLINT - // role in dmlc jobs - std::string dmlc_role; // NOLINT - // port of tracker address - int tracker_port; // NOLINT - // reduce buffer size - size_t reduce_buffer_size; // NOLINT - // reduction method - int reduce_method; // NOLINT - // minimum count of cells to use ring based method - size_t reduce_ring_mincount; // NOLINT - // minimum block size per tree reduce - size_t tree_reduce_minsize; // NOLINT - // current rank - int rank; // NOLINT - // world size - int world_size; // NOLINT - // connect retry time - int connect_retry; // NOLINT - // by default, if rabit worker not recover in half an hour exit - std::chrono::seconds timeout_sec{std::chrono::seconds{1800}}; // NOLINT - // flag to enable rabit_timeout - bool rabit_timeout = false; // NOLINT - // Enable TCP node delay - bool rabit_enable_tcp_no_delay = false; // NOLINT -}; -} // namespace engine -} // namespace rabit -#endif // RABIT_ALLREDUCE_BASE_H_ diff --git a/rabit/src/allreduce_mock.h b/rabit/src/allreduce_mock.h deleted file mode 100644 index b2434658688b..000000000000 --- a/rabit/src/allreduce_mock.h +++ /dev/null @@ -1,147 +0,0 @@ -/*! - * Copyright by Contributors - * \file allreduce_mock.h - * \brief Mock test module of AllReduce engine, - * insert failures in certain call point, to test if the engine is robust to failure - * - * \author Ignacio Cano, Tianqi Chen - */ -#ifndef RABIT_ALLREDUCE_MOCK_H_ -#define RABIT_ALLREDUCE_MOCK_H_ -#include -#include -#include -#include -#include "rabit/internal/engine.h" -#include "allreduce_base.h" - -namespace rabit { -namespace engine { -class AllreduceMock : public AllreduceBase { - public: - // constructor - AllreduceMock() { - num_trial_ = 0; - force_local_ = 0; - report_stats_ = 0; - tsum_allreduce_ = 0.0; - tsum_allgather_ = 0.0; - } - // destructor - ~AllreduceMock() override = default; - void SetParam(const char *name, const char *val) override { - AllreduceBase::SetParam(name, val); - // additional parameters - if (!strcmp(name, "rabit_num_trial")) num_trial_ = atoi(val); - if (!strcmp(name, "DMLC_NUM_ATTEMPT")) num_trial_ = atoi(val); - if (!strcmp(name, "report_stats")) report_stats_ = atoi(val); - if (!strcmp(name, "force_local")) force_local_ = atoi(val); - if (!strcmp(name, "mock")) { - MockKey k; - utils::Check(sscanf(val, "%d,%d,%d,%d", - &k.rank, &k.version, &k.seqno, &k.ntrial) == 4, - "invalid mock parameter"); - mock_map_[k] = 1; - } - } - void Allreduce(void *sendrecvbuf_, size_t type_nbytes, size_t count, - ReduceFunction reducer, PreprocFunction prepare_fun, - void *prepare_arg) override { - this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "AllReduce"); - double tstart = dmlc::GetTime(); - AllreduceBase::Allreduce(sendrecvbuf_, type_nbytes, count, reducer, - prepare_fun, prepare_arg); - tsum_allreduce_ += dmlc::GetTime() - tstart; - } - void Allgather(void *sendrecvbuf, size_t total_size, size_t slice_begin, - size_t slice_end, size_t size_prev_slice) override { - this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Allgather"); - double tstart = dmlc::GetTime(); - AllreduceBase::Allgather(sendrecvbuf, total_size, slice_begin, slice_end, - size_prev_slice); - tsum_allgather_ += dmlc::GetTime() - tstart; - } - void Broadcast(void *sendrecvbuf_, size_t total_size, int root) override { - this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "Broadcast"); - AllreduceBase::Broadcast(sendrecvbuf_, total_size, root); - } - int LoadCheckPoint() override { - tsum_allreduce_ = 0.0; - tsum_allgather_ = 0.0; - time_checkpoint_ = dmlc::GetTime(); - if (force_local_ == 0) { - return AllreduceBase::LoadCheckPoint(); - } else { - return AllreduceBase::LoadCheckPoint(); - } - } - void CheckPoint() override { - this->Verify(MockKey(rank, version_number, seq_counter, num_trial_), "CheckPoint"); - double tstart = dmlc::GetTime(); - double tbet_chkpt = tstart - time_checkpoint_; - AllreduceBase::CheckPoint(); - time_checkpoint_ = dmlc::GetTime(); - double tcost = dmlc::GetTime() - tstart; - if (report_stats_ != 0 && rank == 0) { - std::stringstream ss; - ss << "[v" << version_number << "] global_size=" - << ",check_tcost="<< tcost <<" sec" - << ",allreduce_tcost=" << tsum_allreduce_ << " sec" - << ",allgather_tcost=" << tsum_allgather_ << " sec" - << ",between_chpt=" << tbet_chkpt << "sec\n"; - this->TrackerPrint(ss.str()); - } - tsum_allreduce_ = 0.0; - tsum_allgather_ = 0.0; - } - - protected: - // force checkpoint to local - int force_local_; - // whether report statistics - int report_stats_; - // sum of allreduce - double tsum_allreduce_; - // sum of allgather - double tsum_allgather_; - double time_checkpoint_; - - private: - // key to identify the mock stage - struct MockKey { - int rank; - int version; - int seqno; - int ntrial; - MockKey() = default; - MockKey(int rank, int version, int seqno, int ntrial) - : rank(rank), version(version), seqno(seqno), ntrial(ntrial) {} - inline bool operator==(const MockKey &b) const { - return rank == b.rank && - version == b.version && - seqno == b.seqno && - ntrial == b.ntrial; - } - inline bool operator<(const MockKey &b) const { - if (rank != b.rank) return rank < b.rank; - if (version != b.version) return version < b.version; - if (seqno != b.seqno) return seqno < b.seqno; - return ntrial < b.ntrial; - } - }; - // number of failure trials - int num_trial_; - // record all mock actions - std::map mock_map_; - // used to generate all kinds of exceptions - inline void Verify(const MockKey &key, const char *name) { - if (mock_map_.count(key) != 0) { - num_trial_ += 1; - // data processing frameworks runs on shared process - throw dmlc::Error(std::to_string(rank) + "@@@Hit Mock Error: " + name); - } - } -}; -} // namespace engine -} // namespace rabit -#endif // RABIT_ALLREDUCE_MOCK_H_ diff --git a/rabit/src/engine.cc b/rabit/src/engine.cc deleted file mode 100644 index 89f25fa1e013..000000000000 --- a/rabit/src/engine.cc +++ /dev/null @@ -1,106 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file engine.cc - * \brief this file governs which implementation of engine we are actually using - * provides an singleton of engine interface - * - * \author Tianqi Chen, Ignacio Cano, Tianyi Zhou - */ -#include -#include - -#include -#include "rabit/internal/engine.h" -#include "allreduce_base.h" - -namespace rabit { -namespace engine { -// singleton sync manager -#ifndef RABIT_USE_BASE -#ifndef RABIT_USE_MOCK -using Manager = AllreduceBase; -#else -typedef AllreduceMock Manager; -#endif // RABIT_USE_MOCK -#else -typedef AllreduceBase Manager; -#endif // RABIT_USE_BASE - -/*! \brief entry to to easily hold returning information */ -struct ThreadLocalEntry { - /*! \brief stores the current engine */ - std::unique_ptr engine; - /*! \brief whether init has been called */ - bool initialized{false}; - /*! \brief constructor */ - ThreadLocalEntry() = default; -}; - -// define the threadlocal store. -using EngineThreadLocal = dmlc::ThreadLocalStore; - -/*! \brief intiialize the synchronization module */ -bool Init(int argc, char *argv[]) { - ThreadLocalEntry* e = EngineThreadLocal::Get(); - if (e->engine.get() == nullptr) { - e->initialized = true; - e->engine.reset(new Manager()); - return e->engine->Init(argc, argv); - } else { - return true; - } -} - -/*! \brief finalize syncrhonization module */ -bool Finalize() { - ThreadLocalEntry* e = EngineThreadLocal::Get(); - if (e->engine.get() != nullptr) { - if (e->engine->Shutdown()) { - e->engine.reset(nullptr); - e->initialized = false; - return true; - } else { - return false; - } - } else { - return true; - } -} - -/*! \brief singleton method to get engine */ -IEngine *GetEngine() { - // un-initialized default manager. - static AllreduceBase default_manager; - ThreadLocalEntry* e = EngineThreadLocal::Get(); - IEngine* ptr = e->engine.get(); - if (ptr == nullptr) { - utils::Check(!e->initialized, "the rabit has not been initialized"); - return &default_manager; - } else { - return ptr; - } -} - -// perform in-place allgather, on sendrecvbuf -void Allgather(void *sendrecvbuf_, size_t total_size, - size_t slice_begin, - size_t slice_end, - size_t size_prev_slice) { - GetEngine()->Allgather(sendrecvbuf_, total_size, slice_begin, - slice_end, size_prev_slice); -} - - -// perform in-place allreduce, on sendrecvbuf -void Allreduce_(void *sendrecvbuf, // NOLINT - size_t type_nbytes, - size_t count, - IEngine::ReduceFunction red, - mpi::DataType, - mpi::OpType , - IEngine::PreprocFunction prepare_fun, - void *prepare_arg) { - GetEngine()->Allreduce(sendrecvbuf, type_nbytes, count, red, prepare_fun, prepare_arg); -} -} // namespace engine -} // namespace rabit diff --git a/rabit/src/engine_mock.cc b/rabit/src/engine_mock.cc deleted file mode 100644 index 5c0f8505e425..000000000000 --- a/rabit/src/engine_mock.cc +++ /dev/null @@ -1,14 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file engine_mock.cc - * \brief this is an engine implementation that will - * insert failures in certain call point, to test if the engine is robust to failure - * \author Tianqi Chen - */ -// define use MOCK, os we will use mock Manager -#define NOMINMAX -// switch engine to AllreduceMock -#define RABIT_USE_MOCK -#include -#include "allreduce_mock.h" -#include "engine.cc" diff --git a/rabit/src/rabit_c_api.cc b/rabit/src/rabit_c_api.cc deleted file mode 100644 index c90fae83052e..000000000000 --- a/rabit/src/rabit_c_api.cc +++ /dev/null @@ -1,342 +0,0 @@ -// Copyright by Contributors -// implementations in ctypes -#include -#include -#include -#include "rabit/rabit.h" -#include "rabit/c_api.h" - -#include "../../src/c_api/c_api_error.h" - -namespace rabit { -namespace c_api { -// helper use to avoid BitOR operator -template -struct FHelper { - static void - Allreduce(DType *senrecvbuf_, - size_t count, - void (*prepare_fun)(void *arg), - void *prepare_arg) { - rabit::Allreduce(senrecvbuf_, count, - prepare_fun, prepare_arg); - } -}; - -template -struct FHelper { - static void - Allreduce(DType *, - size_t , - void (*)(void *arg), - void *) { - utils::Error("DataType does not support bitwise AND operation"); - } -}; - -template -struct FHelper { - static void - Allreduce(DType *, - size_t , - void (*)(void *arg), - void *) { - utils::Error("DataType does not support bitwise OR operation"); - } -}; - -template -struct FHelper { - static void - Allreduce(DType *, - size_t , - void (*)(void *arg), - void *) { - utils::Error("DataType does not support bitwise XOR operation"); - } -}; - -template -void Allreduce(void *sendrecvbuf_, - size_t count, - engine::mpi::DataType enum_dtype, - void (*prepare_fun)(void *arg), - void *prepare_arg) { - using namespace engine::mpi; // NOLINT - switch (enum_dtype) { - case kChar: - rabit::Allreduce - (static_cast(sendrecvbuf_), - count, prepare_fun, prepare_arg); - return; - case kUChar: - rabit::Allreduce - (static_cast(sendrecvbuf_), - count, prepare_fun, prepare_arg); - return; - case kInt: - rabit::Allreduce - (static_cast(sendrecvbuf_), - count, prepare_fun, prepare_arg); - return; - case kUInt: - rabit::Allreduce - (static_cast(sendrecvbuf_), - count, prepare_fun, prepare_arg); - return; - case kLong: - rabit::Allreduce - (static_cast(sendrecvbuf_), // NOLINT(*) - count, prepare_fun, prepare_arg); - return; - case kULong: - rabit::Allreduce - (static_cast(sendrecvbuf_), // NOLINT(*) - count, prepare_fun, prepare_arg); - return; - case kFloat: - FHelper::Allreduce - (static_cast(sendrecvbuf_), - count, prepare_fun, prepare_arg); - return; - case kDouble: - FHelper::Allreduce - (static_cast(sendrecvbuf_), - count, prepare_fun, prepare_arg); - return; - default: utils::Error("unknown data_type"); - } -} -void Allreduce(void *sendrecvbuf, - size_t count, - engine::mpi::DataType enum_dtype, - engine::mpi::OpType enum_op, - void (*prepare_fun)(void *arg), - void *prepare_arg) { - using namespace engine::mpi; // NOLINT - switch (enum_op) { - case kMax: - Allreduce - (sendrecvbuf, - count, enum_dtype, - prepare_fun, prepare_arg); - return; - case kMin: - Allreduce - (sendrecvbuf, - count, enum_dtype, - prepare_fun, prepare_arg); - return; - case kSum: - Allreduce - (sendrecvbuf, - count, enum_dtype, - prepare_fun, prepare_arg); - return; - case kBitwiseAND: - Allreduce - (sendrecvbuf, - count, enum_dtype, - prepare_fun, prepare_arg); - return; - case kBitwiseOR: - Allreduce - (sendrecvbuf, - count, enum_dtype, - prepare_fun, prepare_arg); - return; - case kBitwiseXOR: - Allreduce - (sendrecvbuf, - count, enum_dtype, - prepare_fun, prepare_arg); - return; - default: utils::Error("unknown enum_op"); - } -} - -void Allgather(void *sendrecvbuf_, - size_t total_size, - size_t beginIndex, - size_t size_node_slice, - size_t size_prev_slice, - int enum_dtype) { - using namespace engine::mpi; // NOLINT - size_t type_size = 0; - switch (enum_dtype) { - case kChar: - type_size = sizeof(char); - rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, - beginIndex * type_size, (beginIndex + size_node_slice) * type_size, - size_prev_slice * type_size); - break; - case kUChar: - type_size = sizeof(unsigned char); - rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, - beginIndex * type_size, (beginIndex + size_node_slice) * type_size, - size_prev_slice * type_size); - break; - case kInt: - type_size = sizeof(int); - rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, - beginIndex * type_size, (beginIndex + size_node_slice) * type_size, - size_prev_slice * type_size); - break; - case kUInt: - type_size = sizeof(unsigned); - rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, - beginIndex * type_size, (beginIndex + size_node_slice) * type_size, - size_prev_slice * type_size); - break; - case kLong: - type_size = sizeof(int64_t); - rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, - beginIndex * type_size, (beginIndex + size_node_slice) * type_size, - size_prev_slice * type_size); - break; - case kULong: - type_size = sizeof(uint64_t); - rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, - beginIndex * type_size, (beginIndex + size_node_slice) * type_size, - size_prev_slice * type_size); - break; - case kFloat: - type_size = sizeof(float); - rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, - beginIndex * type_size, (beginIndex + size_node_slice) * type_size, - size_prev_slice * type_size); - break; - case kDouble: - type_size = sizeof(double); - rabit::Allgather(static_cast(sendrecvbuf_), total_size * type_size, - beginIndex * type_size, (beginIndex + size_node_slice) * type_size, - size_prev_slice * type_size); - break; - default: utils::Error("unknown data_type"); - } -} - -// wrapper for serialization -struct ReadWrapper : public Serializable { - std::string *p_str; - explicit ReadWrapper(std::string *p_str) - : p_str(p_str) {} - void Load(Stream *fi) override { - uint64_t sz; - utils::Assert(fi->Read(&sz, sizeof(sz)) != 0, - "Read pickle string"); - p_str->resize(sz); - if (sz != 0) { - utils::Assert(fi->Read(&(*p_str)[0], sizeof(char) * sz) != 0, - "Read pickle string"); - } - } - void Save(Stream *) const override { - utils::Error("not implemented"); - } -}; - -struct WriteWrapper : public Serializable { - const char *data; - size_t length; - explicit WriteWrapper(const char *data, - size_t length) - : data(data), length(length) { - } - void Load(Stream *) override { - utils::Error("not implemented"); - } - void Save(Stream *fo) const override { - uint64_t sz = static_cast(length); - fo->Write(&sz, sizeof(sz)); - fo->Write(data, length * sizeof(char)); - } -}; -} // namespace c_api -} // namespace rabit - -RABIT_DLL bool RabitInit(int argc, char *argv[]) { - auto ret = rabit::Init(argc, argv); - if (!ret) { - XGBAPISetLastError("Failed to initialize RABIT."); - } - return ret; -} - -RABIT_DLL int RabitFinalize() { - auto ret = rabit::Finalize(); - if (!ret) { - XGBAPISetLastError("Failed to shutdown RABIT worker."); - } - return static_cast(ret); -} - -RABIT_DLL int RabitGetRingPrevRank() { - return rabit::GetRingPrevRank(); -} - -RABIT_DLL int RabitGetRank() { - return rabit::GetRank(); -} - -RABIT_DLL int RabitGetWorldSize() { - return rabit::GetWorldSize(); -} - -RABIT_DLL int RabitIsDistributed() { - return rabit::IsDistributed(); -} - -RABIT_DLL int RabitTrackerPrint(const char *msg) { - API_BEGIN() - std::string m(msg); - rabit::TrackerPrint(m); - API_END() -} - -RABIT_DLL void RabitGetProcessorName(char *out_name, - rbt_ulong *out_len, - rbt_ulong max_len) { - std::string s = rabit::GetProcessorName(); - if (s.length() > max_len) { - s.resize(max_len - 1); - } - strcpy(out_name, s.c_str()); // NOLINT(*) - *out_len = static_cast(s.length()); -} - -RABIT_DLL int RabitBroadcast(void *sendrecv_data, - rbt_ulong size, int root) { - API_BEGIN() - rabit::Broadcast(sendrecv_data, size, root); - API_END() -} - -RABIT_DLL int RabitAllgather(void *sendrecvbuf_, size_t total_size, - size_t beginIndex, size_t size_node_slice, - size_t size_prev_slice, int enum_dtype) { - API_BEGIN() - rabit::c_api::Allgather( - sendrecvbuf_, total_size, beginIndex, size_node_slice, size_prev_slice, - static_cast(enum_dtype)); - API_END() -} - -RABIT_DLL int RabitAllreduce(void *sendrecvbuf, size_t count, int enum_dtype, - int enum_op, void (*prepare_fun)(void *arg), - void *prepare_arg) { - API_BEGIN() - rabit::c_api::Allreduce(sendrecvbuf, count, - static_cast(enum_dtype), - static_cast(enum_op), - prepare_fun, prepare_arg); - API_END() -} - -RABIT_DLL int RabitVersionNumber() { - return rabit::VersionNumber(); -} - -RABIT_DLL int RabitLinkTag() { - return 0; -} diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 0f4748bfec27..45160baea51f 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1,5 +1,5 @@ /** - * Copyright 2014-2024 by XGBoost Contributors + * Copyright 2014-2024, XGBoost Contributors */ #include "xgboost/c_api.h" @@ -15,9 +15,9 @@ #include // for pair #include // for vector -#include "../collective/communicator-inl.h" // for Allreduce, Broadcast, Finalize, GetProcessor... #include "../common/api_entry.h" // for XGBAPIThreadLocalEntry #include "../common/charconv.h" // for from_chars, to_chars, NumericLimits, from_ch... +#include "../common/error_msg.h" // for NoFederated #include "../common/hist_util.h" // for HistogramCuts #include "../common/io.h" // for FileExtension, LoadSequentialFile, MemoryBuf... #include "../common/threading_utils.h" // for OmpGetNumThreads, ParallelFor @@ -27,11 +27,10 @@ #include "../data/simple_dmatrix.h" // for SimpleDMatrix #include "c_api_error.h" // for xgboost_CHECK_C_ARG_PTR, API_END, API_BEGIN #include "c_api_utils.h" // for RequiredArg, OptionalArg, GetMissing, CastDM... -#include "dmlc/base.h" // for BeginPtr, DMLC_ATTRIBUTE_UNUSED +#include "dmlc/base.h" // for BeginPtr #include "dmlc/io.h" // for Stream #include "dmlc/parameter.h" // for FieldAccessEntry, FieldEntry, ParamManager #include "dmlc/thread_local.h" // for ThreadLocalStore -#include "rabit/c_api.h" // for RabitLinkTag #include "xgboost/base.h" // for bst_ulong, bst_float, GradientPair, bst_feat... #include "xgboost/context.h" // for Context #include "xgboost/data.h" // for DMatrix, MetaInfo, DataType, ExtSparsePage @@ -46,10 +45,6 @@ #include "xgboost/string_view.h" // for StringView, operator<< #include "xgboost/version_config.h" // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR, XGBOOS... -#if defined(XGBOOST_USE_FEDERATED) -#include "../../plugin/federated/federated_server.h" -#endif - using namespace xgboost; // NOLINT(*); XGB_DLL void XGBoostVersion(int* major, int* minor, int* patch) { @@ -614,8 +609,8 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const API_BEGIN(); CHECK_HANDLE(); xgboost_CHECK_C_ARG_PTR(field); - auto const& p_fmat = *static_cast *>(handle); - p_fmat->SetInfo(field, info, xgboost::DataType::kFloat32, len); + auto const &p_fmat = *static_cast *>(handle); + p_fmat->SetInfo(field, linalg::Make1dInterface(info, len)); API_END(); } @@ -634,8 +629,9 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const API_BEGIN(); CHECK_HANDLE(); xgboost_CHECK_C_ARG_PTR(field); + LOG(WARNING) << error::DeprecatedFunc(__func__, "2.1.0", "XGDMatrixSetInfoFromInterface"); auto const &p_fmat = *static_cast *>(handle); - p_fmat->SetInfo(field, info, xgboost::DataType::kUInt32, len); + p_fmat->SetInfo(field, linalg::Make1dInterface(info, len)); API_END(); } @@ -679,19 +675,52 @@ XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void xgboost::bst_ulong size, int type) { API_BEGIN(); CHECK_HANDLE(); + LOG(WARNING) << error::DeprecatedFunc(__func__, "2.1.0", "XGDMatrixSetInfoFromInterface"); auto const &p_fmat = *static_cast *>(handle); CHECK(type >= 1 && type <= 4); xgboost_CHECK_C_ARG_PTR(field); - p_fmat->SetInfo(field, data, static_cast(type), size); - API_END(); -} -XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle, const unsigned *group, xgboost::bst_ulong len) { - API_BEGIN(); - CHECK_HANDLE(); - LOG(WARNING) << "XGDMatrixSetGroup is deprecated, use `XGDMatrixSetUIntInfo` instead."; - auto const &p_fmat = *static_cast *>(handle); - p_fmat->SetInfo("group", group, xgboost::DataType::kUInt32, len); + Context ctx; + auto dtype = static_cast(type); + std::string str; + auto proc = [&](auto cast_d_ptr) { + using T = std::remove_pointer_t; + auto t = linalg::TensorView( + common::Span{cast_d_ptr, static_cast::index_type>(size)}, + {size}, DeviceOrd::CPU()); + CHECK(t.CContiguous()); + Json iface{linalg::ArrayInterface(t)}; + CHECK(ArrayInterface<1>{iface}.is_contiguous); + str = Json::Dump(iface); + return str; + }; + + // Legacy code using XGBoost dtype, which is a small subset of array interface types. + switch (dtype) { + case xgboost::DataType::kFloat32: { + auto cast_ptr = reinterpret_cast(data); + p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr)); + break; + } + case xgboost::DataType::kDouble: { + auto cast_ptr = reinterpret_cast(data); + p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr)); + break; + } + case xgboost::DataType::kUInt32: { + auto cast_ptr = reinterpret_cast(data); + p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr)); + break; + } + case xgboost::DataType::kUInt64: { + auto cast_ptr = reinterpret_cast(data); + p_fmat->Info().SetInfo(ctx, field, proc(cast_ptr)); + break; + } + default: + LOG(FATAL) << "Unknown data type" << static_cast(dtype); + } + API_END(); } @@ -987,7 +1016,7 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle, DMatrixHandle dtrain, bs bst_float *hess, xgboost::bst_ulong len) { API_BEGIN(); CHECK_HANDLE(); - error::DeprecatedFunc(__func__, "2.1.0", "XGBoosterTrainOneIter"); + LOG(WARNING) << error::DeprecatedFunc(__func__, "2.1.0", "XGBoosterTrainOneIter"); auto *learner = static_cast(handle); auto ctx = learner->Ctx()->MakeCPU(); @@ -1725,76 +1754,3 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *config, *out_features = dmlc::BeginPtr(feature_names_c); API_END(); } - -XGB_DLL int XGCommunicatorInit(char const* json_config) { - API_BEGIN(); - xgboost_CHECK_C_ARG_PTR(json_config); - Json config{Json::Load(StringView{json_config})}; - collective::Init(config); - API_END(); -} - -XGB_DLL int XGCommunicatorFinalize() { - API_BEGIN(); - collective::Finalize(); - API_END(); -} - -XGB_DLL int XGCommunicatorGetRank(void) { - return collective::GetRank(); -} - -XGB_DLL int XGCommunicatorGetWorldSize(void) { - return collective::GetWorldSize(); -} - -XGB_DLL int XGCommunicatorIsDistributed(void) { - return collective::IsDistributed(); -} - -XGB_DLL int XGCommunicatorPrint(char const *message) { - API_BEGIN(); - collective::Print(message); - API_END(); -} - -XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) { - API_BEGIN(); - auto& local = *GlobalConfigAPIThreadLocalStore::Get(); - local.ret_str = collective::GetProcessorName(); - xgboost_CHECK_C_ARG_PTR(name_str); - *name_str = local.ret_str.c_str(); - API_END(); -} - -XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) { - API_BEGIN(); - collective::Broadcast(send_receive_buffer, size, root); - API_END(); -} - -XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype, - int enum_op) { - API_BEGIN(); - collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op); - API_END(); -} - -#if defined(XGBOOST_USE_FEDERATED) -XGB_DLL int XGBRunFederatedServer(int port, std::size_t world_size, char const *server_key_path, - char const *server_cert_path, char const *client_cert_path) { - API_BEGIN(); - federated::RunServer(port, world_size, server_key_path, server_cert_path, client_cert_path); - API_END(); -} - -// Run a server without SSL for local testing. -XGB_DLL int XGBRunInsecureFederatedServer(int port, std::size_t world_size) { - API_BEGIN(); - federated::RunInsecureServer(port, world_size); - API_END(); -} -#endif - -// force link rabit -static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag(); diff --git a/src/c_api/c_api_error.cc b/src/c_api/c_api_error.cc index 10e864c806c4..59dfb8854dc3 100644 --- a/src/c_api/c_api_error.cc +++ b/src/c_api/c_api_error.cc @@ -1,22 +1,28 @@ -/*! - * Copyright (c) 2015 by Contributors +/** + * Copyright 2015-2023, XGBoost Contributors * \file c_api_error.cc * \brief C error handling */ +#include "./c_api_error.h" + #include + #include "xgboost/c_api.h" -#include "./c_api_error.h" +#include "../collective/comm.h" +#include "../collective/comm_group.h" struct XGBAPIErrorEntry { std::string last_error; + std::int32_t code{-1}; }; using XGBAPIErrorStore = dmlc::ThreadLocalStore; -XGB_DLL const char *XGBGetLastError() { - return XGBAPIErrorStore::Get()->last_error.c_str(); -} +XGB_DLL const char* XGBGetLastError() { return XGBAPIErrorStore::Get()->last_error.c_str(); } void XGBAPISetLastError(const char* msg) { XGBAPIErrorStore::Get()->last_error = msg; + XGBAPIErrorStore::Get()->code = -1; } + +XGB_DLL int XGBGetLastErrorCode() { return XGBAPIErrorStore::Get()->code; } diff --git a/src/c_api/c_api_error.h b/src/c_api/c_api_error.h index 11c4403847e0..0ad4ac073dbd 100644 --- a/src/c_api/c_api_error.h +++ b/src/c_api/c_api_error.h @@ -10,6 +10,7 @@ #include #include "c_api_utils.h" +#include "xgboost/collective/result.h" /*! \brief macro to guard beginning and end section of all functions */ #ifdef LOG_CAPI_INVOCATION @@ -30,7 +31,7 @@ #define API_END() \ } catch (dmlc::Error & _except_) { \ return XGBAPIHandleException(_except_); \ - } catch (std::exception const &_except_) { \ + } catch (std::exception const& _except_) { \ return XGBAPIHandleException(dmlc::Error(_except_.what())); \ } \ return 0; // NOLINT(*) @@ -48,7 +49,7 @@ void XGBAPISetLastError(const char* msg); * \param e the exception * \return the return value of API after exception is handled */ -inline int XGBAPIHandleException(const dmlc::Error &e) { +inline int XGBAPIHandleException(const dmlc::Error& e) { XGBAPISetLastError(e.what()); return -1; } diff --git a/src/c_api/c_api_utils.h b/src/c_api/c_api_utils.h index 95efb5b9d747..04b0fc0072af 100644 --- a/src/c_api/c_api_utils.h +++ b/src/c_api/c_api_utils.h @@ -1,17 +1,18 @@ /** - * Copyright 2021-2023, XGBoost Contributors + * Copyright 2021-2024, XGBoost Contributors */ #ifndef XGBOOST_C_API_C_API_UTILS_H_ #define XGBOOST_C_API_C_API_UTILS_H_ -#include -#include -#include -#include // for shared_ptr -#include // for string -#include // for make_tuple -#include // for move -#include +#include // for min +#include // for size_t +#include // for multiplies +#include // for shared_ptr +#include // for accumulate +#include // for string +#include // for make_tuple +#include // for move +#include // for vector #include "../common/json_utils.h" // for TypeCheck #include "xgboost/c_api.h" diff --git a/src/c_api/coll_c_api.cc b/src/c_api/coll_c_api.cc index 01713dbad419..1da22610367b 100644 --- a/src/c_api/coll_c_api.cc +++ b/src/c_api/coll_c_api.cc @@ -1,16 +1,23 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include // for seconds -#include // for size_t #include // for future #include // for unique_ptr #include // for string +#include // for sleep_for #include // for is_same_v, remove_pointer_t #include // for pair -#include "../collective/tracker.h" // for RabitTracker -#include "c_api_error.h" // for API_BEGIN +#include "../collective/allgather.h" // for Allgather +#include "../collective/allreduce.h" // for Allreduce +#include "../collective/broadcast.h" // for Broadcast +#include "../collective/comm.h" // for DefaultTimeoutSec +#include "../collective/comm_group.h" // for GlobalCommGroup +#include "../collective/communicator-inl.h" // for GetProcessorName +#include "../collective/tracker.h" // for RabitTracker +#include "../common/timer.h" // for Timer +#include "c_api_error.h" // for API_BEGIN #include "xgboost/c_api.h" #include "xgboost/collective/result.h" // for Result #include "xgboost/json.h" // for Json @@ -18,15 +25,41 @@ #if defined(XGBOOST_USE_FEDERATED) #include "../../plugin/federated/federated_tracker.h" // for FederatedTracker -#else -#include "../common/error_msg.h" // for NoFederated #endif +namespace xgboost::collective { +void Allreduce(void *send_receive_buffer, std::size_t count, std::int32_t data_type, int op) { + Context ctx; + DispatchDType(static_cast(data_type), [&](auto t) { + using T = decltype(t); + auto data = linalg::MakeTensorView( + &ctx, common::Span{static_cast(send_receive_buffer), count}, count); + auto rc = Allreduce(&ctx, *GlobalCommGroup(), data, static_cast(op)); + SafeColl(rc); + }); +} + +void Broadcast(void *send_receive_buffer, std::size_t size, int root) { + Context ctx; + auto rc = Broadcast(&ctx, *GlobalCommGroup(), + linalg::MakeVec(static_cast(send_receive_buffer), size), root); + SafeColl(rc); +} + +void Allgather(void *send_receive_buffer, std::size_t size) { + Context ctx; + auto const &comm = GlobalCommGroup(); + auto rc = Allgather(&ctx, *comm, + linalg::MakeVec(reinterpret_cast(send_receive_buffer), size)); + SafeColl(rc); +} +} // namespace xgboost::collective + using namespace xgboost; // NOLINT namespace { using TrackerHandleT = - std::pair, std::shared_future>; + std::pair, std::shared_future>; TrackerHandleT *GetTrackerHandle(TrackerHandle handle) { xgboost_CHECK_C_ARG_PTR(handle); @@ -40,17 +73,30 @@ struct CollAPIEntry { }; using CollAPIThreadLocalStore = dmlc::ThreadLocalStore; -void WaitImpl(TrackerHandleT *ptr) { - std::chrono::seconds wait_for{100}; +void WaitImpl(TrackerHandleT *ptr, std::chrono::seconds timeout) { + constexpr std::int64_t kDft{collective::DefaultTimeoutSec()}; + std::chrono::seconds wait_for{collective::HasTimeout(timeout) ? std::min(kDft, timeout.count()) + : kDft}; + + common::Timer timer; + timer.Start(); + + auto ref = ptr->first; // hold a reference to that free don't delete it while waiting. + auto fut = ptr->second; while (fut.valid()) { auto res = fut.wait_for(wait_for); CHECK(res != std::future_status::deferred); + if (res == std::future_status::ready) { auto const &rc = ptr->second.get(); - CHECK(rc.OK()) << rc.Report(); + collective::SafeColl(rc); break; } + + if (timer.Duration() > timeout && collective::HasTimeout(timeout)) { + collective::SafeColl(collective::Fail("Timeout waiting for the tracker.")); + } } } } // namespace @@ -62,15 +108,15 @@ XGB_DLL int XGTrackerCreate(char const *config, TrackerHandle *handle) { Json jconfig = Json::Load(config); auto type = RequiredArg(jconfig, "dmlc_communicator", __func__); - std::unique_ptr tptr; + std::shared_ptr tptr; if (type == "federated") { #if defined(XGBOOST_USE_FEDERATED) - tptr = std::make_unique(jconfig); + tptr = std::make_shared(jconfig); #else LOG(FATAL) << error::NoFederated(); #endif // defined(XGBOOST_USE_FEDERATED) } else if (type == "rabit") { - tptr = std::make_unique(jconfig); + tptr = std::make_shared(jconfig); } else { LOG(FATAL) << "Unknown communicator:" << type; } @@ -93,7 +139,7 @@ XGB_DLL int XGTrackerWorkerArgs(TrackerHandle handle, char const **args) { API_END(); } -XGB_DLL int XGTrackerRun(TrackerHandle handle) { +XGB_DLL int XGTrackerRun(TrackerHandle handle, char const *) { API_BEGIN(); auto *ptr = GetTrackerHandle(handle); CHECK(!ptr->second.valid()) << "Tracker is already running."; @@ -101,19 +147,107 @@ XGB_DLL int XGTrackerRun(TrackerHandle handle) { API_END(); } -XGB_DLL int XGTrackerWait(TrackerHandle handle, char const *config) { +XGB_DLL int XGTrackerWaitFor(TrackerHandle handle, char const *config) { API_BEGIN(); auto *ptr = GetTrackerHandle(handle); xgboost_CHECK_C_ARG_PTR(config); auto jconfig = Json::Load(StringView{config}); - WaitImpl(ptr); + // Internally, 0 indicates no timeout, which is the default since we don't want to + // interrupt the model training. + xgboost_CHECK_C_ARG_PTR(config); + auto timeout = OptionalArg(jconfig, "timeout", std::int64_t{0}); + WaitImpl(ptr, std::chrono::seconds{timeout}); API_END(); } XGB_DLL int XGTrackerFree(TrackerHandle handle) { API_BEGIN(); + using namespace std::chrono_literals; // NOLINT auto *ptr = GetTrackerHandle(handle); - WaitImpl(ptr); + ptr->first->Stop(); + // The wait is not necessary since we just called stop, just reusing the function to do + // any potential cleanups. + WaitImpl(ptr, ptr->first->Timeout()); + common::Timer timer; + timer.Start(); + // Make sure no one else is waiting on the tracker. + while (!ptr->first.unique()) { + auto ela = timer.Duration().count(); + if (collective::HasTimeout(ptr->first->Timeout()) && ela > ptr->first->Timeout().count()) { + LOG(WARNING) << "Time out " << ptr->first->Timeout().count() + << " seconds reached for TrackerFree, killing the tracker."; + break; + } + std::this_thread::sleep_for(64ms); + } delete ptr; API_END(); } + +XGB_DLL int XGCommunicatorInit(char const *json_config) { + API_BEGIN(); + xgboost_CHECK_C_ARG_PTR(json_config); + Json config{Json::Load(StringView{json_config})}; + collective::GlobalCommGroupInit(config); + API_END(); +} + +XGB_DLL int XGCommunicatorFinalize(void) { + API_BEGIN(); + collective::GlobalCommGroupFinalize(); + API_END(); +} + +XGB_DLL int XGCommunicatorGetRank(void) { + API_BEGIN(); + return collective::GetRank(); + API_END(); +} + +XGB_DLL int XGCommunicatorGetWorldSize(void) { return collective::GetWorldSize(); } + +XGB_DLL int XGCommunicatorIsDistributed(void) { return collective::IsDistributed(); } + +XGB_DLL int XGCommunicatorPrint(char const *message) { + API_BEGIN(); + collective::Print(message); + API_END(); +} + +XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) { + API_BEGIN(); + auto &local = *CollAPIThreadLocalStore::Get(); + local.ret_str = collective::GetProcessorName(); + xgboost_CHECK_C_ARG_PTR(name_str); + *name_str = local.ret_str.c_str(); + API_END(); +} + +XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) { + API_BEGIN(); + collective::Broadcast(send_receive_buffer, size, root); + API_END(); +} + +XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype, + int enum_op) { + API_BEGIN(); + collective::Allreduce(send_receive_buffer, count, enum_dtype, enum_op); + API_END(); +} + +// Not exposed to the public since the previous implementation didn't and we don't want to +// add unnecessary communicator API to a machine learning library. +XGB_DLL int XGCommunicatorAllgather(void *send_receive_buffer, size_t count) { + API_BEGIN(); + collective::Allgather(send_receive_buffer, count); + API_END(); +} + +// Not yet exposed to the public, error recovery is still WIP. +XGB_DLL int XGCommunicatorSignalError() { + API_BEGIN(); + auto msg = XGBGetLastError(); + SafeColl(xgboost::collective::GlobalCommGroup()->SignalError(xgboost::collective::Fail(msg))); + API_END() +} diff --git a/src/cli_main.cc b/src/cli_main.cc index 276d67da8db4..54a3450276f4 100644 --- a/src/cli_main.cc +++ b/src/cli_main.cc @@ -22,7 +22,6 @@ #include #include #include -#include "collective/communicator-inl.h" #include "common/common.h" #include "common/config.h" #include "common/io.h" @@ -193,10 +192,6 @@ class CLI { void CLITrain() { const double tstart_data_load = dmlc::GetTime(); - if (collective::IsDistributed()) { - std::string pname = collective::GetProcessorName(); - LOG(CONSOLE) << "start " << pname << ":" << collective::GetRank(); - } // load in data. std::shared_ptr dtrain(DMatrix::Load( param_.train_path, ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(), @@ -235,15 +230,9 @@ class CLI { version += 1; } std::string res = learner_->EvalOneIter(i, eval_datasets, eval_data_names); - if (collective::IsDistributed()) { - if (collective::GetRank() == 0) { - LOG(TRACKER) << res; - } - } else { - LOG(CONSOLE) << res; - } - if (param_.save_period != 0 && (i + 1) % param_.save_period == 0 && - collective::GetRank() == 0) { + LOG(CONSOLE) << res; + + if (param_.save_period != 0 && (i + 1) % param_.save_period == 0) { std::ostringstream os; os << param_.model_dir << '/' << std::setfill('0') << std::setw(4) << i + 1 << ".model"; @@ -256,8 +245,7 @@ class CLI { << " sec"; // always save final round if ((param_.save_period == 0 || - param_.num_round % param_.save_period != 0) && - collective::GetRank() == 0) { + param_.num_round % param_.save_period != 0)) { std::ostringstream os; if (param_.model_out == CLIParam::kNull) { os << param_.model_dir << '/' << std::setfill('0') << std::setw(4) @@ -465,13 +453,6 @@ class CLI { } } - // Initialize the collective communicator. - Json json{JsonObject()}; - for (auto& kv : cfg) { - json[kv.first] = String(kv.second); - } - collective::Init(json); - param_.Configure(cfg); } @@ -507,10 +488,6 @@ class CLI { } return 0; } - - ~CLI() { - collective::Finalize(); - } }; } // namespace xgboost diff --git a/src/collective/aggregator.cuh b/src/collective/aggregator.cuh index 66766470b9d2..d85e328aacdb 100644 --- a/src/collective/aggregator.cuh +++ b/src/collective/aggregator.cuh @@ -1,5 +1,5 @@ /** - * Copyright 2023 by XGBoost contributors + * Copyright 2023-2024, XGBoost contributors * * Higher level functions built on top the Communicator API, taking care of behavioral differences * between row-split vs column-split distributed training, and horizontal vs vertical federated @@ -13,7 +13,8 @@ #include #include -#include "communicator-inl.cuh" +#include "allreduce.h" +#include "xgboost/collective/result.h" // for Result namespace xgboost::collective { @@ -24,15 +25,17 @@ namespace xgboost::collective { * column-wise (vertically), the original values are returned. * * @tparam T The type of the values. + * * @param info MetaInfo about the DMatrix. - * @param device The device id. * @param values Pointer to the inputs to sum. * @param size Number of values to sum. */ -template -void GlobalSum(MetaInfo const& info, DeviceOrd device, T* values, size_t size) { +template +[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info, + linalg::TensorView values) { if (info.IsRowSplit()) { - collective::AllReduce(device.ordinal, values, size); + return collective::Allreduce(ctx, values, collective::Op::kSum); } + return Success(); } } // namespace xgboost::collective diff --git a/src/collective/aggregator.h b/src/collective/aggregator.h index 8a5b31c36546..a328a61203e1 100644 --- a/src/collective/aggregator.h +++ b/src/collective/aggregator.h @@ -11,11 +11,44 @@ #include #include +#include "allreduce.h" +#include "broadcast.h" +#include "comm.h" #include "communicator-inl.h" #include "xgboost/collective/result.h" // for Result #include "xgboost/data.h" // for MetaINfo namespace xgboost::collective { +namespace detail { +template +[[nodiscard]] Result TryApplyWithLabels(Context const* ctx, Fn&& fn) { + std::string msg; + if (collective::GetRank() == 0) { + try { + fn(); + } catch (dmlc::Error const& e) { + msg = e.what(); + } + } + std::size_t msg_size{msg.size()}; + auto rc = Success() << [&] { + auto rc = collective::Broadcast(ctx, linalg::MakeVec(&msg_size, 1), 0); + return rc; + } << [&] { + if (msg_size > 0) { + msg.resize(msg_size); + return collective::Broadcast(ctx, linalg::MakeVec(msg.data(), msg.size()), 0); + } + return Success(); + } << [&] { + if (msg_size > 0) { + LOG(FATAL) << msg; + } + return Success(); + }; + return rc; +} +} // namespace detail /** * @brief Apply the given function where the labels are. @@ -30,29 +63,19 @@ namespace xgboost::collective { * @param size The size of the buffer. * @param function The function used to calculate the results. */ -template -void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::size_t size, - FN&& function) { +template +void ApplyWithLabels(Context const* ctx, MetaInfo const& info, void* buffer, std::size_t size, + Fn&& fn) { if (info.IsVerticalFederated()) { - // We assume labels are only available on worker 0, so the calculation is done there and result - // broadcast to other workers. - std::string message; - if (collective::GetRank() == 0) { - try { - std::forward(function)(); - } catch (dmlc::Error& e) { - message = e.what(); - } - } - - collective::Broadcast(&message, 0); - if (message.empty()) { - collective::Broadcast(buffer, size, 0); - } else { - LOG(FATAL) << &message[0]; - } + auto rc = detail::TryApplyWithLabels(ctx, fn) << [&] { + // We assume labels are only available on worker 0, so the calculation is done there and + // result broadcast to other workers. + return collective::Broadcast( + ctx, linalg::MakeVec(reinterpret_cast(buffer), size), 0); + }; + SafeColl(rc); } else { - std::forward(function)(); + std::forward(fn)(); } } @@ -69,37 +92,24 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::si * @param result The HostDeviceVector storing the results. * @param function The function used to calculate the results. */ -template -void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector* result, - Function&& function) { +template +void ApplyWithLabels(Context const* ctx, MetaInfo const& info, HostDeviceVector* result, + Fn&& fn) { if (info.IsVerticalFederated()) { // We assume labels are only available on worker 0, so the calculation is done there and result // broadcast to other workers. - std::string message; - if (collective::GetRank() == 0) { - try { - std::forward(function)(); - } catch (dmlc::Error& e) { - message = e.what(); - } - } - - collective::Broadcast(&message, 0); - if (!message.empty()) { - LOG(FATAL) << &message[0]; - return; - } + auto rc = detail::TryApplyWithLabels(ctx, fn); - std::size_t size{}; - if (collective::GetRank() == 0) { - size = result->Size(); - } - collective::Broadcast(&size, sizeof(std::size_t), 0); - - result->Resize(size); - collective::Broadcast(result->HostPointer(), size * sizeof(T), 0); + std::size_t size{result->Size()}; + rc = std::move(rc) << [&] { + return collective::Broadcast(ctx, linalg::MakeVec(&size, 1), 0); + } << [&] { + result->Resize(size); + return collective::Broadcast(ctx, linalg::MakeVec(result->HostPointer(), size), 0); + }; + SafeColl(rc); } else { - std::forward(function)(); + std::forward(fn)(); } } @@ -115,11 +125,12 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector* * @return The global max of the input. */ template -std::enable_if_t, T> GlobalMax(Context const*, +std::enable_if_t, T> GlobalMax(Context const* ctx, MetaInfo const& info, T value) { if (info.IsRowSplit()) { - collective::Allreduce(&value, 1); + auto rc = collective::Allreduce(ctx, linalg::MakeVec(&value, 1), collective::Op::kMax); + SafeColl(rc); } return value; } @@ -136,19 +147,14 @@ std::enable_if_t, T> GlobalMax(Context co * @param size Number of values to sum. */ template -[[nodiscard]] Result GlobalSum(Context const*, MetaInfo const& info, +[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info, linalg::TensorView values) { if (info.IsRowSplit()) { - collective::Allreduce(values.Values().data(), values.Size()); + return collective::Allreduce(ctx, values, collective::Op::kSum); } return Success(); } -template -[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info, Container* values) { - return GlobalSum(ctx, info, values->data(), values->size()); -} - /** * @brief Find the global ratio of the given two values across all workers. * @@ -165,7 +171,7 @@ template T GlobalRatio(Context const* ctx, MetaInfo const& info, T dividend, T divisor) { std::array results{dividend, divisor}; auto rc = GlobalSum(ctx, info, linalg::MakeVec(results.data(), results.size())); - collective::SafeColl(rc); + SafeColl(rc); std::tie(dividend, divisor) = std::tuple_cat(results); if (divisor <= 0) { return std::numeric_limits::quiet_NaN(); diff --git a/src/collective/allgather.cc b/src/collective/allgather.cc index 148cb6cd2882..c2c7a500f0f3 100644 --- a/src/collective/allgather.cc +++ b/src/collective/allgather.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include "allgather.h" @@ -7,6 +7,7 @@ #include // for size_t #include // for int8_t, int32_t, int64_t #include // for shared_ptr +#include // for move #include "broadcast.h" #include "comm.h" // for Comm, Channel @@ -29,18 +30,24 @@ Result RingAllgather(Comm const& comm, common::Span data, std::size auto rc = Success() << [&] { auto send_rank = (rank + world - r + worker_off) % world; auto send_off = send_rank * segment_size; - send_off = std::min(send_off, data.size_bytes()); - auto send_seg = data.subspan(send_off, std::min(segment_size, data.size_bytes() - send_off)); + bool is_last_segment = send_rank == (world - 1); + auto send_nbytes = is_last_segment ? (data.size_bytes() - send_off) : segment_size; + auto send_seg = data.subspan(send_off, send_nbytes); + CHECK_NE(send_seg.size(), 0); return next_ch->SendAll(send_seg.data(), send_seg.size_bytes()); } << [&] { auto recv_rank = (rank + world - r - 1 + worker_off) % world; auto recv_off = recv_rank * segment_size; - recv_off = std::min(recv_off, data.size_bytes()); - auto recv_seg = data.subspan(recv_off, std::min(segment_size, data.size_bytes() - recv_off)); + bool is_last_segment = recv_rank == (world - 1); + auto recv_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : segment_size; + auto recv_seg = data.subspan(recv_off, recv_nbytes); + CHECK_NE(recv_seg.size(), 0); return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes()); - } << [&] { return prev_ch->Block(); }; + } << [&] { + return comm.Block(); + }; if (!rc.OK()) { - return rc; + return Fail("Ring allgather failed, current iteration:" + std::to_string(r), std::move(rc)); } } @@ -54,7 +61,8 @@ Result BroadcastAllgatherV(Comm const& comm, common::Span si auto as_bytes = sizes[r]; auto rc = Broadcast(comm, recv.subspan(offset, as_bytes), r); if (!rc.OK()) { - return rc; + return Fail("Broadcast AllgatherV failed, current iteration:" + std::to_string(r), + std::move(rc)); } offset += as_bytes; } @@ -91,12 +99,57 @@ namespace detail { auto recv_size = sizes[recv_rank]; auto recv_seg = erased_result.subspan(recv_off, recv_size); return prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes()); - } << [&] { return prev_ch->Block(); }; + } << [&] { + return prev_ch->Block(); + }; if (!rc.OK()) { - return rc; + return Fail("Ring AllgatherV failed, current iterataion:" + std::to_string(r), std::move(rc)); } } return comm.Block(); } } // namespace detail + +[[nodiscard]] std::vector> VectorAllgatherV( + Context const* ctx, CommGroup const& comm, std::vector> const& input) { + auto n_inputs = input.size(); + std::vector sizes(n_inputs); + std::transform(input.cbegin(), input.cend(), sizes.begin(), + [](auto const& vec) { return vec.size(); }); + + std::vector recv_segments(comm.World() + 1, 0); + + HostDeviceVector recv; + auto rc = + AllgatherV(ctx, comm, linalg::MakeVec(sizes.data(), sizes.size()), &recv_segments, &recv); + SafeColl(rc); + + auto global_sizes = common::RestoreType(recv.ConstHostSpan()); + std::vector offset(global_sizes.size() + 1); + offset[0] = 0; + for (std::size_t i = 1; i < offset.size(); i++) { + offset[i] = offset[i - 1] + global_sizes[i - 1]; + } + + std::vector collected; + for (auto const& vec : input) { + collected.insert(collected.end(), vec.cbegin(), vec.cend()); + } + rc = AllgatherV(ctx, comm, linalg::MakeVec(collected.data(), collected.size()), &recv_segments, + &recv); + SafeColl(rc); + auto out = common::RestoreType(recv.ConstHostSpan()); + + std::vector> result; + for (std::size_t i = 1; i < offset.size(); ++i) { + std::vector local(out.cbegin() + offset[i - 1], out.cbegin() + offset[i]); + result.emplace_back(std::move(local)); + } + return result; +} + +[[nodiscard]] std::vector> VectorAllgatherV( + Context const* ctx, std::vector> const& input) { + return VectorAllgatherV(ctx, *GlobalCommGroup(), input); +} } // namespace xgboost::collective diff --git a/src/collective/allgather.h b/src/collective/allgather.h index 4f13014be618..ca44c3916cc3 100644 --- a/src/collective/allgather.h +++ b/src/collective/allgather.h @@ -1,25 +1,27 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once #include // for size_t #include // for int32_t #include // for shared_ptr #include // for accumulate +#include // for string #include // for remove_cv_t #include // for vector -#include "../common/type.h" // for EraseType +#include "../common/type.h" // for EraseType #include "comm.h" // for Comm, Channel +#include "comm_group.h" // for CommGroup #include "xgboost/collective/result.h" // for Result -#include "xgboost/linalg.h" -#include "xgboost/span.h" // for Span +#include "xgboost/linalg.h" // for MakeVec +#include "xgboost/span.h" // for Span namespace xgboost::collective { namespace cpu_impl { /** * @param worker_off Segment offset. For example, if the rank 2 worker specifies - * worker_off = 1, then it owns the third segment. + * worker_off = 1, then it owns the third segment (2 + 1). */ [[nodiscard]] Result RingAllgather(Comm const& comm, common::Span data, std::size_t segment_size, std::int32_t worker_off, @@ -51,8 +53,10 @@ inline void AllgatherVOffset(common::Span sizes, } // namespace detail template -[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span data, std::size_t size) { - auto n_bytes = sizeof(T) * size; +[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span data) { + // This function is also used for ring allreduce, hence we allow the last segment to be + // larger due to round-down. + auto n_bytes_per_segment = data.size_bytes() / comm.World(); auto erased = common::EraseType(data); auto rank = comm.Rank(); @@ -61,7 +65,7 @@ template auto prev_ch = comm.Chan(prev); auto next_ch = comm.Chan(next); - auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes, 0, prev_ch, next_ch); + auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes_per_segment, 0, prev_ch, next_ch); if (!rc.OK()) { return rc; } @@ -76,7 +80,7 @@ template std::vector sizes(world, 0); sizes[rank] = data.size_bytes(); - auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()}, 1); + auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()}); if (!rc.OK()) { return rc; } @@ -98,4 +102,115 @@ template return detail::RingAllgatherV(comm, sizes, s_segments, erased_result); } + +template +[[nodiscard]] Result Allgather(Context const* ctx, CommGroup const& comm, + linalg::VectorView data) { + if (!comm.IsDistributed()) { + return Success(); + } + CHECK(data.Contiguous()); + auto erased = common::EraseType(data.Values()); + + auto const& cctx = comm.Ctx(ctx, data.Device()); + auto backend = comm.Backend(data.Device()); + return backend->Allgather(cctx, erased); +} + +/** + * @brief Gather all data from all workers. + * + * @param data The input and output buffer, needs to be pre-allocated by the caller. + */ +template +[[nodiscard]] Result Allgather(Context const* ctx, linalg::VectorView data) { + auto const& cg = *GlobalCommGroup(); + if (data.Size() % cg.World() != 0) { + return Fail("The total number of elements should be multiple of the number of workers."); + } + return Allgather(ctx, cg, data); +} + +template +[[nodiscard]] Result AllgatherV(Context const* ctx, CommGroup const& comm, + linalg::VectorView data, + std::vector* recv_segments, + HostDeviceVector* recv) { + if (!comm.IsDistributed()) { + return Success(); + } + std::vector sizes(comm.World(), 0); + sizes[comm.Rank()] = data.Values().size_bytes(); + auto erased_sizes = common::EraseType(common::Span{sizes.data(), sizes.size()}); + auto rc = comm.Backend(DeviceOrd::CPU()) + ->Allgather(comm.Ctx(ctx, DeviceOrd::CPU()), erased_sizes); + if (!rc.OK()) { + return rc; + } + + recv_segments->resize(sizes.size() + 1); + detail::AllgatherVOffset(sizes, common::Span{recv_segments->data(), recv_segments->size()}); + auto total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0LL); + recv->SetDevice(data.Device()); + recv->Resize(total_bytes); + + auto s_segments = common::Span{recv_segments->data(), recv_segments->size()}; + + auto backend = comm.Backend(data.Device()); + auto erased = common::EraseType(data.Values()); + + return backend->AllgatherV( + comm.Ctx(ctx, data.Device()), erased, common::Span{sizes.data(), sizes.size()}, s_segments, + data.Device().IsCUDA() ? recv->DeviceSpan() : recv->HostSpan(), AllgatherVAlgo::kBcast); +} + +/** + * @brief Allgather with variable length data. + * + * @param data The input data. + * @param recv_segments segment size for each worker. [0, 2, 5] means [0, 2) elements are + * from the first worker, [2, 5) elements are from the second one. + * @param recv The buffer storing the result. + */ +template +[[nodiscard]] Result AllgatherV(Context const* ctx, linalg::VectorView data, + std::vector* recv_segments, + HostDeviceVector* recv) { + return AllgatherV(ctx, *GlobalCommGroup(), data, recv_segments, recv); +} + +[[nodiscard]] std::vector> VectorAllgatherV( + Context const* ctx, CommGroup const& comm, std::vector> const& input); + +/** + * @brief Gathers variable-length data from all processes and distributes it to all processes. + * + * @param inputs All the inputs from the local worker. The number of inputs can vary + * across different workers. Along with which, the size of each vector in + * the input can also vary. + * + * @return The AllgatherV result, containing vectors from all workers. + */ +[[nodiscard]] std::vector> VectorAllgatherV( + Context const* ctx, std::vector> const& input); + +/** + * @brief Gathers variable-length strings from all processes and distributes them to all processes. + * @param input Variable-length list of variable-length strings. + */ +[[nodiscard]] inline Result AllgatherStrings(std::vector const& input, + std::vector* p_result) { + std::vector> inputs(input.size()); + for (std::size_t i = 0; i < input.size(); ++i) { + inputs[i] = {input[i].cbegin(), input[i].cend()}; + } + Context ctx; + auto out = VectorAllgatherV(&ctx, *GlobalCommGroup(), inputs); + auto& result = *p_result; + result.resize(out.size()); + for (std::size_t i = 0; i < out.size(); ++i) { + result[i] = {out[i].cbegin(), out[i].cend()}; + } + return Success(); +} } // namespace xgboost::collective diff --git a/src/collective/allreduce.cc b/src/collective/allreduce.cc index 93b76355f807..3b201c99d648 100644 --- a/src/collective/allreduce.cc +++ b/src/collective/allreduce.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include "allreduce.h" @@ -16,7 +16,44 @@ #include "xgboost/span.h" // for Span namespace xgboost::collective::cpu_impl { +namespace { template +Result RingAllreduceSmall(Comm const& comm, common::Span data, Func const& op) { + auto rank = comm.Rank(); + auto world = comm.World(); + + auto next_ch = comm.Chan(BootstrapNext(rank, world)); + auto prev_ch = comm.Chan(BootstrapPrev(rank, world)); + + std::vector buffer(data.size_bytes() * world, 0); + auto s_buffer = common::Span{buffer.data(), buffer.size()}; + + auto offset = data.size_bytes() * rank; + auto self = s_buffer.subspan(offset, data.size_bytes()); + std::copy_n(data.data(), data.size_bytes(), self.data()); + + auto typed = common::RestoreType(s_buffer); + auto rc = RingAllgather(comm, typed); + + if (!rc.OK()) { + return Fail("Ring allreduce small failed.", std::move(rc)); + } + auto first = s_buffer.subspan(0, data.size_bytes()); + CHECK_EQ(first.size(), data.size()); + + for (std::int32_t r = 1; r < world; ++r) { + auto offset = data.size_bytes() * r; + auto buf = s_buffer.subspan(offset, data.size_bytes()); + op(buf, first); + } + std::copy_n(first.data(), first.size(), data.data()); + + return Success(); +} +} // namespace + +template +// note that n_bytes_in_seg is calculated with round-down. Result RingScatterReduceTyped(Comm const& comm, common::Span data, std::size_t n_bytes_in_seg, Func const& op) { auto rank = comm.Rank(); @@ -27,32 +64,42 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span data, auto next_ch = comm.Chan(dst_rank); auto prev_ch = comm.Chan(src_rank); - std::vector buffer(n_bytes_in_seg, 0); + std::vector buffer(data.size_bytes() - (world - 1) * n_bytes_in_seg, -1); auto s_buf = common::Span{buffer.data(), buffer.size()}; for (std::int32_t r = 0; r < world - 1; ++r) { - // send to ring next - auto send_off = ((rank + world - r) % world) * n_bytes_in_seg; - send_off = std::min(send_off, data.size_bytes()); - auto seg_nbytes = std::min(data.size_bytes() - send_off, n_bytes_in_seg); - auto send_seg = data.subspan(send_off, seg_nbytes); + common::Span seg, recv_seg; + auto rc = Success() << [&] { + // send to ring next + auto send_rank = (rank + world - r) % world; + auto send_off = send_rank * n_bytes_in_seg; - auto rc = next_ch->SendAll(send_seg); - if (!rc.OK()) { - return rc; - } + bool is_last_segment = send_rank == (world - 1); + + auto seg_nbytes = is_last_segment ? data.size_bytes() - send_off : n_bytes_in_seg; + CHECK_EQ(seg_nbytes % sizeof(T), 0); - // receive from ring prev - auto recv_off = ((rank + world - r - 1) % world) * n_bytes_in_seg; - recv_off = std::min(recv_off, data.size_bytes()); - seg_nbytes = std::min(data.size_bytes() - recv_off, n_bytes_in_seg); - CHECK_EQ(seg_nbytes % sizeof(T), 0); - auto recv_seg = data.subspan(recv_off, seg_nbytes); - auto seg = s_buf.subspan(0, recv_seg.size()); + auto send_seg = data.subspan(send_off, seg_nbytes); + return next_ch->SendAll(send_seg); + } << [&] { + // receive from ring prev + auto recv_rank = (rank + world - r - 1) % world; + auto recv_off = recv_rank * n_bytes_in_seg; - rc = std::move(rc) << [&] { return prev_ch->RecvAll(seg); } << [&] { return comm.Block(); }; + bool is_last_segment = recv_rank == (world - 1); + + auto seg_nbytes = is_last_segment ? (data.size_bytes() - recv_off) : n_bytes_in_seg; + CHECK_EQ(seg_nbytes % sizeof(T), 0); + + recv_seg = data.subspan(recv_off, seg_nbytes); + seg = s_buf.subspan(0, recv_seg.size()); + return prev_ch->RecvAll(seg); + } << [&] { + return comm.Block(); + }; if (!rc.OK()) { - return rc; + return Fail("Ring scatter reduce failed, current iteration:" + std::to_string(r), + std::move(rc)); } // accumulate to recv_seg @@ -68,6 +115,9 @@ Result RingAllreduce(Comm const& comm, common::Span data, Func cons if (comm.World() == 1) { return Success(); } + if (data.size_bytes() == 0) { + return Success(); + } return DispatchDType(type, [&](auto t) { using T = decltype(t); // Divide the data into segments according to the number of workers. @@ -75,10 +125,14 @@ Result RingAllreduce(Comm const& comm, common::Span data, Func cons CHECK_EQ(data.size_bytes() % n_bytes_elem, 0); auto n = data.size_bytes() / n_bytes_elem; auto world = comm.World(); - auto n_bytes_in_seg = common::DivRoundUp(n, world) * sizeof(T); + if (n < static_cast(world)) { + return RingAllreduceSmall(comm, data, op); + } + + auto n_bytes_in_seg = (n / world) * sizeof(T); auto rc = RingScatterReduceTyped(comm, data, n_bytes_in_seg, op); if (!rc.OK()) { - return rc; + return Fail("Ring Allreduce failed.", std::move(rc)); } auto prev = BootstrapPrev(comm.Rank(), comm.World()); @@ -88,7 +142,9 @@ Result RingAllreduce(Comm const& comm, common::Span data, Func cons return std::move(rc) << [&] { return RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch); - } << [&] { return comm.Block(); }; + } << [&] { + return comm.Block(); + }; }); } } // namespace xgboost::collective::cpu_impl diff --git a/src/collective/allreduce.h b/src/collective/allreduce.h index 0c94d11cc35d..3e88cca112cb 100644 --- a/src/collective/allreduce.h +++ b/src/collective/allreduce.h @@ -1,15 +1,18 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once #include // for int8_t #include // for function #include // for is_invocable_v, enable_if_t +#include // for vector #include "../common/type.h" // for EraseType, RestoreType -#include "../data/array_interface.h" // for ArrayInterfaceHandler +#include "../data/array_interface.h" // for ToDType, ArrayInterfaceHandler #include "comm.h" // for Comm, RestoreType +#include "comm_group.h" // for GlobalCommGroup #include "xgboost/collective/result.h" // for Result +#include "xgboost/context.h" // for Context #include "xgboost/span.h" // for Span namespace xgboost::collective { @@ -27,8 +30,7 @@ std::enable_if_t, common::Span> auto erased = common::EraseType(data); auto type = ToDType::kType; - auto erased_fn = [type, redop](common::Span lhs, - common::Span out) { + auto erased_fn = [redop](common::Span lhs, common::Span out) { CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction."; auto lhs_t = common::RestoreType(lhs); auto rhs_t = common::RestoreType(out); @@ -37,4 +39,40 @@ std::enable_if_t, common::Span> return cpu_impl::RingAllreduce(comm, erased, erased_fn, type); } + +template +[[nodiscard]] Result Allreduce(Context const* ctx, CommGroup const& comm, + linalg::TensorView data, Op op) { + if (!comm.IsDistributed()) { + return Success(); + } + CHECK(data.Contiguous()); + auto erased = common::EraseType(data.Values()); + auto type = ToDType::kType; + + auto backend = comm.Backend(data.Device()); + return backend->Allreduce(comm.Ctx(ctx, data.Device()), erased, type, op); +} + +template +[[nodiscard]] Result Allreduce(Context const* ctx, linalg::TensorView data, Op op) { + return Allreduce(ctx, *GlobalCommGroup(), data, op); +} + +/** + * @brief Specialization for std::vector. + */ +template +[[nodiscard]] Result Allreduce(Context const* ctx, std::vector* data, Op op) { + return Allreduce(ctx, linalg::MakeVec(data->data(), data->size()), op); +} + +/** + * @brief Specialization for scalar value. + */ +template +[[nodiscard]] std::enable_if_t && std::is_trivial_v, Result> +Allreduce(Context const* ctx, T* data, Op op) { + return Allreduce(ctx, linalg::MakeVec(data, 1), op); +} } // namespace xgboost::collective diff --git a/src/collective/broadcast.h b/src/collective/broadcast.h index 28db83815cd4..61cab8cdd8f6 100644 --- a/src/collective/broadcast.h +++ b/src/collective/broadcast.h @@ -1,11 +1,15 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once #include // for int32_t, int8_t -#include "comm.h" // for Comm -#include "xgboost/collective/result.h" // for +#include "../common/type.h" +#include "comm.h" // for Comm, EraseType +#include "comm_group.h" // for CommGroup +#include "xgboost/collective/result.h" // for Result +#include "xgboost/context.h" // for Context +#include "xgboost/linalg.h" // for VectorView #include "xgboost/span.h" // for Span namespace xgboost::collective { @@ -23,4 +27,21 @@ template common::Span{reinterpret_cast(data.data()), n_total_bytes}; return cpu_impl::Broadcast(comm, erased, root); } + +template +[[nodiscard]] Result Broadcast(Context const* ctx, CommGroup const& comm, + linalg::VectorView data, std::int32_t root) { + if (!comm.IsDistributed()) { + return Success(); + } + CHECK(data.Contiguous()); + auto erased = common::EraseType(data.Values()); + auto backend = comm.Backend(data.Device()); + return backend->Broadcast(comm.Ctx(ctx, data.Device()), erased, root); +} + +template +[[nodiscard]] Result Broadcast(Context const* ctx, linalg::VectorView data, std::int32_t root) { + return Broadcast(ctx, *GlobalCommGroup(), data, root); +} } // namespace xgboost::collective diff --git a/src/collective/coll.cc b/src/collective/coll.cc index 755d44c905cd..b720d09b7eb9 100644 --- a/src/collective/coll.cc +++ b/src/collective/coll.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include "coll.h" @@ -7,6 +7,7 @@ #include // for size_t #include // for int8_t, int64_t #include // for bit_and, bit_or, bit_xor, plus +#include // for string #include // for is_floating_point_v, is_same_v #include // for move @@ -37,6 +38,10 @@ bool constexpr IsFloatingPointV() { auto redop_fn = [](auto lhs, auto out, auto elem_op) { auto p_lhs = lhs.data(); auto p_out = out.data(); +#if defined(__GNUC__) || defined(__clang__) + // For the sum op, one can verify the simd by: addps %xmm15, %xmm14 +#pragma omp simd +#endif for (std::size_t i = 0; i < lhs.size(); ++i) { p_out[i] = elem_op(p_lhs[i], p_out[i]); } @@ -56,6 +61,8 @@ bool constexpr IsFloatingPointV() { return cpu_impl::RingAllreduce(comm, data, erased_fn, type); }; + std::string msg{"Floating point is not supported for bit wise collective operations."}; + auto rc = DispatchDType(type, [&](auto t) { using T = decltype(t); switch (op) { @@ -70,21 +77,21 @@ bool constexpr IsFloatingPointV() { } case Op::kBitwiseAND: { if constexpr (IsFloatingPointV()) { - return Fail("Invalid type."); + return Fail(msg); } else { return fn(std::bit_and<>{}, t); } } case Op::kBitwiseOR: { if constexpr (IsFloatingPointV()) { - return Fail("Invalid type."); + return Fail(msg); } else { return fn(std::bit_or<>{}, t); } } case Op::kBitwiseXOR: { if constexpr (IsFloatingPointV()) { - return Fail("Invalid type."); + return Fail(msg); } else { return fn(std::bit_xor<>{}, t); } @@ -101,9 +108,8 @@ bool constexpr IsFloatingPointV() { return cpu_impl::Broadcast(comm, data, root); } -[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span data, - std::int64_t size) { - return RingAllgather(comm, data, size); +[[nodiscard]] Result Coll::Allgather(Comm const& comm, common::Span data) { + return RingAllgather(comm, data); } [[nodiscard]] Result Coll::AllgatherV(Comm const& comm, common::Span data, diff --git a/src/collective/coll.cu b/src/collective/coll.cu index d1b66a8ced82..433f1e49dbe5 100644 --- a/src/collective/coll.cu +++ b/src/collective/coll.cu @@ -1,10 +1,9 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #if defined(XGBOOST_USE_NCCL) #include // for int8_t, int64_t -#include "../common/cuda_context.cuh" #include "../common/device_helpers.cuh" #include "../data/array_interface.h" #include "allgather.h" // for AllgatherVOffset @@ -162,14 +161,14 @@ ncclRedOp_t GetNCCLRedOp(Op const& op) { } << [&] { return nccl->Block(); }; } -[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span data, - std::int64_t size) { +[[nodiscard]] Result NCCLColl::Allgather(Comm const& comm, common::Span data) { if (!comm.IsDistributed()) { return Success(); } auto nccl = dynamic_cast(&comm); CHECK(nccl); auto stub = nccl->Stub(); + auto size = data.size_bytes() / comm.World(); auto send = data.subspan(comm.Rank() * size, size); return Success() << [&] { @@ -192,7 +191,7 @@ Result BroadcastAllgatherV(NCCLComm const* comm, common::Span for (std::int32_t r = 0; r < comm->World(); ++r) { auto as_bytes = sizes[r]; auto rc = stub->Broadcast(data.data(), recv.subspan(offset, as_bytes).data(), as_bytes, - ncclInt8, r, comm->Handle(), dh::DefaultStream()); + ncclInt8, r, comm->Handle(), comm->Stream()); if (!rc.OK()) { return rc; } diff --git a/src/collective/coll.cuh b/src/collective/coll.cuh index 6ededd101732..4d45295d7b84 100644 --- a/src/collective/coll.cuh +++ b/src/collective/coll.cuh @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once @@ -8,8 +8,7 @@ #include "../data/array_interface.h" // for ArrayInterfaceHandler #include "coll.h" // for Coll #include "comm.h" // for Comm -#include "nccl_stub.h" -#include "xgboost/span.h" // for Span +#include "xgboost/span.h" // for Span namespace xgboost::collective { class NCCLColl : public Coll { @@ -20,8 +19,7 @@ class NCCLColl : public Coll { ArrayInterfaceHandler::Type type, Op op) override; [[nodiscard]] Result Broadcast(Comm const& comm, common::Span data, std::int32_t root) override; - [[nodiscard]] Result Allgather(Comm const& comm, common::Span data, - std::int64_t size) override; + [[nodiscard]] Result Allgather(Comm const& comm, common::Span data) override; [[nodiscard]] Result AllgatherV(Comm const& comm, common::Span data, common::Span sizes, common::Span recv_segments, diff --git a/src/collective/coll.h b/src/collective/coll.h index 1afc8ed590f9..96fe35229510 100644 --- a/src/collective/coll.h +++ b/src/collective/coll.h @@ -48,10 +48,8 @@ class Coll : public std::enable_shared_from_this { * @brief Allgather * * @param [in,out] data Data buffer for input and output. - * @param [in] size Size of data for each worker. */ - [[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span data, - std::int64_t size); + [[nodiscard]] virtual Result Allgather(Comm const& comm, common::Span data); /** * @brief Allgather with variable length. * diff --git a/src/collective/comm.cc b/src/collective/comm.cc index 783278b65f1c..543ece6397d8 100644 --- a/src/collective/comm.cc +++ b/src/collective/comm.cc @@ -1,16 +1,19 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include "comm.h" #include // for copy #include // for seconds +#include // for int32_t #include // for exit #include // for shared_ptr #include // for string +#include // for thread #include // for move, forward - -#include "../common/common.h" // for AssertGPUSupport +#if !defined(XGBOOST_USE_NCCL) +#include "../common/common.h" // for AssertNCCLSupport +#endif // !defined(XGBOOST_USE_NCCL) #include "allgather.h" // for RingAllgather #include "protocol.h" // for kMagic #include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE @@ -21,11 +24,7 @@ namespace xgboost::collective { Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, std::int32_t retry, std::string task_id) - : timeout_{timeout}, - retry_{retry}, - tracker_{host, port, -1}, - task_id_{std::move(task_id)}, - loop_{std::shared_ptr{new Loop{timeout}}} {} + : timeout_{timeout}, retry_{retry}, tracker_{host, port, -1}, task_id_{std::move(task_id)} {} Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry, std::string const& task_id, TCPSocket* out, std::int32_t rank, @@ -75,9 +74,11 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st } << [&] { return next->NonBlocking(true); } << [&] { - SockAddrV4 addr; + SockAddress addr; return listener->Accept(prev.get(), &addr); - } << [&] { return prev->NonBlocking(true); }; + } << [&] { + return prev->NonBlocking(true); + }; if (!rc.OK()) { return rc; } @@ -149,24 +150,33 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st } auto rank = comm.Rank(); - auto n_bytes = worker->SendAll(&rank, sizeof(comm.Rank())); - if (n_bytes != sizeof(comm.Rank())) { - return Fail("Failed to send rank."); + std::size_t n_bytes{0}; + auto rc = worker->SendAll(&rank, sizeof(comm.Rank()), &n_bytes); + if (!rc.OK()) { + return rc; + } else if (n_bytes != sizeof(comm.Rank())) { + return Fail("Failed to send rank.", std::move(rc)); } workers[r] = std::move(worker); } for (std::int32_t r = 0; r < comm.Rank(); ++r) { - SockAddrV4 addr; auto peer = std::shared_ptr(TCPSocket::CreatePtr(comm.Domain())); - rc = std::move(rc) << [&] { return listener->Accept(peer.get(), &addr); } - << [&] { return peer->RecvTimeout(timeout); }; + rc = std::move(rc) << [&] { + SockAddress addr; + return listener->Accept(peer.get(), &addr); + } << [&] { + return peer->RecvTimeout(timeout); + }; if (!rc.OK()) { return rc; } std::int32_t rank{-1}; - auto n_bytes = peer->RecvAll(&rank, sizeof(rank)); - if (n_bytes != sizeof(comm.Rank())) { + std::size_t n_bytes{0}; + auto rc = peer->RecvAll(&rank, sizeof(rank), &n_bytes); + if (!rc.OK()) { + return rc; + } else if (n_bytes != sizeof(comm.Rank())) { return Fail("Failed to recv rank."); } workers[rank] = std::move(peer); @@ -182,12 +192,32 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st return Success(); } -RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, - std::int32_t retry, std::string task_id, StringView nccl_path) - : HostComm{std::move(host), port, timeout, retry, std::move(task_id)}, +namespace { +std::string InitLog(std::string task_id, std::int32_t rank) { + if (task_id.empty()) { + return "Rank " + std::to_string(rank); + } + return "Task " + task_id + " got rank " + std::to_string(rank); +} +} // namespace + +RabitComm::RabitComm(std::string const& tracker_host, std::int32_t tracker_port, + std::chrono::seconds timeout, std::int32_t retry, std::string task_id, + StringView nccl_path) + : HostComm{tracker_host, tracker_port, timeout, retry, std::move(task_id)}, nccl_path_{std::move(nccl_path)} { + if (this->TrackerInfo().host.empty()) { + // Not in a distributed environment. + LOG(CONSOLE) << InitLog(task_id_, rank_); + return; + } + + loop_.reset(new Loop{std::chrono::seconds{timeout_}}); // NOLINT auto rc = this->Bootstrap(timeout_, retry_, task_id_); - CHECK(rc.OK()) << rc.Report(); + if (!rc.OK()) { + this->ResetState(); + SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc))); + } } #if !defined(XGBOOST_USE_NCCL) @@ -212,20 +242,54 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr) const { // Start command TCPSocket listener = TCPSocket::Create(tracker.Domain()); - std::int32_t lport = listener.BindHost(); - listener.Listen(); + std::int32_t lport{0}; + rc = std::move(rc) << [&] { + return listener.BindHost(&lport); + } << [&] { + return listener.Listen(); + }; + if (!rc.OK()) { + return rc; + } // create worker for listening to error notice. auto domain = tracker.Domain(); std::shared_ptr error_sock{TCPSocket::CreatePtr(domain)}; - auto eport = error_sock->BindHost(); - error_sock->Listen(); + std::int32_t eport{0}; + rc = std::move(rc) << [&] { + return error_sock->BindHost(&eport); + } << [&] { + return error_sock->Listen(); + }; + if (!rc.OK()) { + return rc; + } + error_port_ = eport; + error_worker_ = std::thread{[error_sock = std::move(error_sock)] { - auto conn = error_sock->Accept(); + TCPSocket conn; + SockAddress addr; + auto rc = error_sock->Accept(&conn, &addr); + // On Linux, a shutdown causes an invalid argument error; + if (rc.Code() == std::errc::invalid_argument) { + return; + } // On Windows, accept returns a closed socket after finalize. if (conn.IsClosed()) { return; } + // The error signal is from the tracker, while shutdown signal is from the shutdown method + // of the RabitComm class (this). + bool is_error{false}; + rc = proto::Error{}.RecvSignal(&conn, &is_error); + if (!rc.OK()) { + LOG(WARNING) << rc.Report(); + return; + } + if (!is_error) { + return; // shutdown + } + LOG(WARNING) << "Another worker is running into error."; #if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0 // exit is nicer than abort as the former performs cleanups. @@ -234,6 +298,9 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr) const { LOG(FATAL) << "abort"; #endif }}; + // The worker thread is detached here to avoid the need to handle it later during + // destruction. For C++, if a thread is not joined or detached, it will segfault during + // destruction. error_worker_.detach(); proto::Start start; @@ -246,11 +313,13 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr) const { // get ring neighbors std::string snext; - tracker.Recv(&snext); + rc = tracker.Recv(&snext); + if (!rc.OK()) { + return Fail("Failed to receive the rank for the next worker.", std::move(rc)); + } auto jnext = Json::Load(StringView{snext}); proto::PeerInfo ninfo{jnext}; - // get the rank of this worker this->rank_ = BootstrapPrev(ninfo.rank, world); this->tracker_.rank = rank_; @@ -258,20 +327,27 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr) const { std::vector> workers; rc = ConnectWorkers(*this, &listener, lport, ninfo, timeout, retry, &workers); if (!rc.OK()) { - return rc; + return Fail("Failed to connect to other workers.", std::move(rc)); } CHECK(this->channels_.empty()); for (auto& w : workers) { if (w) { - rc = std::move(rc) << [&] { return w->SetNoDelay(); } << [&] { return w->NonBlocking(true); } - << [&] { return w->SetKeepAlive(); }; + rc = std::move(rc) << [&] { + return w->SetNoDelay(); + } << [&] { + return w->NonBlocking(true); + } << [&] { + return w->SetKeepAlive(); + }; } if (!rc.OK()) { return rc; } this->channels_.emplace_back(std::make_shared(*this, w)); } + + LOG(CONSOLE) << InitLog(task_id_, rank_); return rc; } @@ -279,6 +355,8 @@ RabitComm::~RabitComm() noexcept(false) { if (!this->IsDistributed()) { return; } + LOG(WARNING) << "The communicator is being destroyed without a call to shutdown first. This can " + "lead to undefined behaviour."; auto rc = this->Shutdown(); if (!rc.OK()) { LOG(WARNING) << rc.Report(); @@ -286,24 +364,52 @@ RabitComm::~RabitComm() noexcept(false) { } [[nodiscard]] Result RabitComm::Shutdown() { + if (!this->IsDistributed()) { + return Success(); + } + // Tell the tracker that this worker is shutting down. TCPSocket tracker; + // Tell the error hanlding thread that we are shutting down. + TCPSocket err_client; + return Success() << [&] { return ConnectTrackerImpl(tracker_, timeout_, retry_, task_id_, &tracker, Rank(), World()); } << [&] { return this->Block(); } << [&] { - Json jcmd{Object{}}; - jcmd["cmd"] = Integer{static_cast(proto::CMD::kShutdown)}; - auto scmd = Json::Dump(jcmd); - auto n_bytes = tracker.Send(scmd); - if (n_bytes != scmd.size()) { - return Fail("Faled to send cmd."); - } + return proto::ShutdownCMD{}.Send(&tracker); + } << [&] { + this->channels_.clear(); return Success(); + } << [&] { + // Use tracker address to determine whether we want to use IPv6. + auto taddr = MakeSockAddress(xgboost::StringView{this->tracker_.host}, this->tracker_.port); + // Shutdown the error handling thread. We signal the thread through socket, + // alternatively, we can get the native handle and use pthread_cancel. But using a + // socket seems to be clearer as we know what's happening. + auto const& addr = taddr.IsV4() ? SockAddrV4::Loopback().Addr() : SockAddrV6::Loopback().Addr(); + // We use hardcoded 10 seconds and 1 retry here since we are just connecting to a + // local socket. For a normal OS, this should be enough time to schedule the + // connection. + auto rc = Connect(StringView{addr}, this->error_port_, 1, + std::min(std::chrono::seconds{10}, timeout_), &err_client); + this->ResetState(); + if (!rc.OK()) { + return Fail("Failed to connect to the error socket.", std::move(rc)); + } + return rc; + } << [&] { + // We put error thread shutdown at the end so that we have a better chance to finish + // the previous more important steps. + return proto::Error{}.SignalShutdown(&err_client); }; } [[nodiscard]] Result RabitComm::LogTracker(std::string msg) const { + if (!this->IsDistributed()) { + LOG(CONSOLE) << msg; + return Success(); + } TCPSocket out; proto::Print print; return Success() << [&] { return this->ConnectTracker(&out); } @@ -311,8 +417,11 @@ RabitComm::~RabitComm() noexcept(false) { } [[nodiscard]] Result RabitComm::SignalError(Result const& res) { - TCPSocket out; - return Success() << [&] { return this->ConnectTracker(&out); } - << [&] { return proto::ErrorCMD{}.WorkerSend(&out, res); }; + TCPSocket tracker; + return Success() << [&] { + return this->ConnectTracker(&tracker); + } << [&] { + return proto::ErrorCMD{}.WorkerSend(&tracker, res); + }; } } // namespace xgboost::collective diff --git a/src/collective/comm.cu b/src/collective/comm.cu index 56681253c170..6566f28fad91 100644 --- a/src/collective/comm.cu +++ b/src/collective/comm.cu @@ -27,7 +27,7 @@ Result GetUniqueId(Comm const& comm, std::shared_ptr stub, std::shared ncclUniqueId id; if (comm.Rank() == kRootRank) { auto rc = stub->GetUniqueId(&id); - CHECK(rc.OK()) << rc.Report(); + SafeColl(rc); } auto rc = coll->Broadcast( comm, common::Span{reinterpret_cast(&id), sizeof(ncclUniqueId)}, kRootRank); @@ -80,9 +80,8 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p auto s_this_uuid = s_uuid.subspan(root.Rank() * kUuidLength, kUuidLength); GetCudaUUID(s_this_uuid, ctx->Device()); - auto rc = pimpl->Allgather(root, common::EraseType(s_uuid), s_this_uuid.size_bytes()); - - CHECK(rc.OK()) << rc.Report(); + auto rc = pimpl->Allgather(root, common::EraseType(s_uuid)); + SafeColl(rc); std::vector> converted(root.World()); std::size_t j = 0; @@ -103,7 +102,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p [&] { return this->stub_->CommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank()); }; - CHECK(rc.OK()) << rc.Report(); + SafeColl(rc); for (std::int32_t r = 0; r < root.World(); ++r) { this->channels_.emplace_back( @@ -114,7 +113,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p NCCLComm::~NCCLComm() { if (nccl_comm_) { auto rc = stub_->CommDestroy(nccl_comm_); - CHECK(rc.OK()) << rc.Report(); + SafeColl(rc); } } } // namespace xgboost::collective diff --git a/src/collective/comm.cuh b/src/collective/comm.cuh index a818d95f8134..4add9ca612e0 100644 --- a/src/collective/comm.cuh +++ b/src/collective/comm.cuh @@ -50,6 +50,10 @@ class NCCLComm : public Comm { auto rc = this->Stream().Sync(false); return GetCUDAResult(rc); } + [[nodiscard]] Result Shutdown() final { + this->ResetState(); + return Success(); + } }; class NCCLChannel : public Channel { diff --git a/src/collective/comm.h b/src/collective/comm.h index 82aa2c45e0bb..0a0f24aadd3d 100644 --- a/src/collective/comm.h +++ b/src/collective/comm.h @@ -1,10 +1,10 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once #include // for seconds #include // for size_t -#include // for int32_t +#include // for int32_t, int64_t #include // for shared_ptr #include // for string #include // for thread @@ -14,13 +14,13 @@ #include "loop.h" // for Loop #include "protocol.h" // for PeerInfo #include "xgboost/collective/result.h" // for Result -#include "xgboost/collective/socket.h" // for TCPSocket +#include "xgboost/collective/socket.h" // for TCPSocket, GetHostName #include "xgboost/context.h" // for Context #include "xgboost/span.h" // for Span namespace xgboost::collective { -inline constexpr std::int32_t DefaultTimeoutSec() { return 300; } // 5min +inline constexpr std::int64_t DefaultTimeoutSec() { return 300; } // 5min inline constexpr std::int32_t DefaultRetry() { return 3; } // indexing into the ring @@ -51,11 +51,25 @@ class Comm : public std::enable_shared_from_this { proto::PeerInfo tracker_; SockDomain domain_{SockDomain::kV4}; + std::thread error_worker_; + std::int32_t error_port_; + std::string task_id_; std::vector> channels_; - std::shared_ptr loop_{new Loop{std::chrono::seconds{ - DefaultTimeoutSec()}}}; // fixme: require federated comm to have a timeout + std::shared_ptr loop_{nullptr}; // fixme: require federated comm to have a timeout + + void ResetState() { + this->world_ = -1; + this->rank_ = 0; + this->timeout_ = std::chrono::seconds{DefaultTimeoutSec()}; + + tracker_ = proto::PeerInfo{}; + this->task_id_.clear(); + channels_.clear(); + + loop_.reset(); + } public: Comm() = default; @@ -75,10 +89,13 @@ class Comm : public std::enable_shared_from_this { [[nodiscard]] auto Retry() const { return retry_; } [[nodiscard]] auto TaskID() const { return task_id_; } - [[nodiscard]] auto Rank() const { return rank_; } - [[nodiscard]] auto World() const { return IsDistributed() ? world_ : 1; } - [[nodiscard]] bool IsDistributed() const { return world_ != -1; } - void Submit(Loop::Op op) const { loop_->Submit(op); } + [[nodiscard]] auto Rank() const noexcept { return rank_; } + [[nodiscard]] auto World() const noexcept { return IsDistributed() ? world_ : 1; } + [[nodiscard]] bool IsDistributed() const noexcept { return world_ != -1; } + void Submit(Loop::Op op) const { + CHECK(loop_); + loop_->Submit(std::move(op)); + } [[nodiscard]] virtual Result Block() const { return loop_->Block(); } [[nodiscard]] virtual std::shared_ptr Chan(std::int32_t rank) const { @@ -88,6 +105,14 @@ class Comm : public std::enable_shared_from_this { [[nodiscard]] virtual Result LogTracker(std::string msg) const = 0; [[nodiscard]] virtual Result SignalError(Result const&) { return Success(); } + /** + * @brief Get a string ID for the current process. + */ + [[nodiscard]] virtual Result ProcessorName(std::string* out) const { + auto rc = GetHostName(out); + return rc; + } + [[nodiscard]] virtual Result Shutdown() = 0; }; /** @@ -105,20 +130,20 @@ class RabitComm : public HostComm { [[nodiscard]] Result Bootstrap(std::chrono::seconds timeout, std::int32_t retry, std::string task_id); - [[nodiscard]] Result Shutdown(); public: // bootstrapping construction. RabitComm() = default; - // ctor for testing where environment is known. - RabitComm(std::string const& host, std::int32_t port, std::chrono::seconds timeout, - std::int32_t retry, std::string task_id, StringView nccl_path); + RabitComm(std::string const& tracker_host, std::int32_t tracker_port, + std::chrono::seconds timeout, std::int32_t retry, std::string task_id, + StringView nccl_path); ~RabitComm() noexcept(false) override; [[nodiscard]] bool IsFederated() const override { return false; } [[nodiscard]] Result LogTracker(std::string msg) const override; [[nodiscard]] Result SignalError(Result const&) override; + [[nodiscard]] Result Shutdown() final; [[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr pimpl) const override; }; diff --git a/src/collective/comm_group.cc b/src/collective/comm_group.cc index f7bbba7549d4..a9b58ecb5505 100644 --- a/src/collective/comm_group.cc +++ b/src/collective/comm_group.cc @@ -1,22 +1,21 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include "comm_group.h" #include // for transform +#include // for tolower #include // for seconds #include // for int32_t +#include // for back_inserter #include // for shared_ptr, unique_ptr #include // for string -#include // for vector -#include "../common/json_utils.h" // for OptionalArg -#include "coll.h" // for Coll -#include "comm.h" // for Comm -#include "tracker.h" // for GetHostAddress -#include "xgboost/collective/result.h" // for Result -#include "xgboost/context.h" // for DeviceOrd -#include "xgboost/json.h" // for Json +#include "../common/json_utils.h" // for OptionalArg +#include "coll.h" // for Coll +#include "comm.h" // for Comm +#include "xgboost/context.h" // for DeviceOrd +#include "xgboost/json.h" // for Json #if defined(XGBOOST_USE_FEDERATED) #include "../../plugin/federated/federated_coll.h" @@ -65,6 +64,9 @@ CommGroup::CommGroup() auto const& obj = get(config); auto it = obj.find(upper); + if (it != obj.cend() && obj.find(name) != obj.cend()) { + LOG(FATAL) << "Duplicated parameter:" << name; + } if (it != obj.cend()) { return OptionalArg(config, upper, dft); } else { @@ -74,18 +76,18 @@ CommGroup::CommGroup() // Common args auto retry = get_param("dmlc_retry", static_cast(DefaultRetry()), Integer{}); auto timeout = - get_param("dmlc_timeout_sec", static_cast(DefaultTimeoutSec()), Integer{}); + get_param("dmlc_timeout", static_cast(DefaultTimeoutSec()), Integer{}); auto task_id = get_param("dmlc_task_id", std::string{}, String{}); if (type == "rabit") { - auto host = get_param("dmlc_tracker_uri", std::string{}, String{}); - auto port = get_param("dmlc_tracker_port", static_cast(0), Integer{}); + auto tracker_host = get_param("dmlc_tracker_uri", std::string{}, String{}); + auto tracker_port = get_param("dmlc_tracker_port", static_cast(0), Integer{}); auto nccl = get_param("dmlc_nccl_path", std::string{DefaultNcclName()}, String{}); - auto ptr = - new CommGroup{std::shared_ptr{new RabitComm{ // NOLINT - host, static_cast(port), std::chrono::seconds{timeout}, - static_cast(retry), task_id, nccl}}, - std::shared_ptr(new Coll{})}; // NOLINT + auto ptr = new CommGroup{ + std::shared_ptr{new RabitComm{ // NOLINT + tracker_host, static_cast(tracker_port), std::chrono::seconds{timeout}, + static_cast(retry), task_id, nccl}}, + std::shared_ptr(new Coll{})}; // NOLINT return ptr; } else if (type == "federated") { #if defined(XGBOOST_USE_FEDERATED) @@ -117,6 +119,34 @@ void GlobalCommGroupInit(Json config) { void GlobalCommGroupFinalize() { auto& sptr = GlobalCommGroup(); + auto rc = sptr->Finalize(); sptr.reset(); + SafeColl(rc); +} + +void Init(Json const& config) { GlobalCommGroupInit(config); } + +void Finalize() { GlobalCommGroupFinalize(); } + +std::int32_t GetRank() noexcept { return GlobalCommGroup()->Rank(); } + +std::int32_t GetWorldSize() noexcept { return GlobalCommGroup()->World(); } + +bool IsDistributed() noexcept { return GlobalCommGroup()->IsDistributed(); } + +[[nodiscard]] bool IsFederated() { + return GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).IsFederated(); +} + +void Print(std::string const& message) { + auto rc = GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).LogTracker(message); + SafeColl(rc); +} + +std::string GetProcessorName() { + std::string out; + auto rc = GlobalCommGroup()->ProcessorName(&out); + SafeColl(rc); + return out; } } // namespace xgboost::collective diff --git a/src/collective/comm_group.h b/src/collective/comm_group.h index 2f6f91d73a79..a98de0c16e51 100644 --- a/src/collective/comm_group.h +++ b/src/collective/comm_group.h @@ -9,7 +9,6 @@ #include "coll.h" // for Comm #include "comm.h" // for Coll #include "xgboost/collective/result.h" // for Result -#include "xgboost/collective/socket.h" // for GetHostName namespace xgboost::collective { /** @@ -31,19 +30,35 @@ class CommGroup { public: CommGroup(); - [[nodiscard]] auto World() const { return comm_->World(); } - [[nodiscard]] auto Rank() const { return comm_->Rank(); } - [[nodiscard]] bool IsDistributed() const { return comm_->IsDistributed(); } + [[nodiscard]] auto World() const noexcept { return comm_->World(); } + [[nodiscard]] auto Rank() const noexcept { return comm_->Rank(); } + [[nodiscard]] bool IsDistributed() const noexcept { return comm_->IsDistributed(); } + + [[nodiscard]] Result Finalize() const { + return Success() << [this] { + if (gpu_comm_) { + return gpu_comm_->Shutdown(); + } + return Success(); + } << [&] { + return comm_->Shutdown(); + }; + } [[nodiscard]] static CommGroup* Create(Json config); [[nodiscard]] std::shared_ptr Backend(DeviceOrd device) const; + /** + * @brief Decide the context to use for communication. + * + * @param ctx Global context, provides the CUDA stream and ordinal. + * @param device The device used by the data to be communicated. + */ [[nodiscard]] Comm const& Ctx(Context const* ctx, DeviceOrd device) const; [[nodiscard]] Result SignalError(Result const& res) { return comm_->SignalError(res); } [[nodiscard]] Result ProcessorName(std::string* out) const { - auto rc = GetHostName(out); - return rc; + return this->comm_->ProcessorName(out); } }; diff --git a/src/collective/communicator-inl.cc b/src/collective/communicator-inl.cc deleted file mode 100644 index 4164855f1cef..000000000000 --- a/src/collective/communicator-inl.cc +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2024, XGBoost contributors - */ -#include "communicator-inl.h" - -namespace xgboost::collective { -[[nodiscard]] std::vector> VectorAllgatherV( - std::vector> const &input) { - auto n_inputs = input.size(); - std::vector sizes(n_inputs); - std::transform(input.cbegin(), input.cend(), sizes.begin(), - [](auto const &vec) { return vec.size(); }); - - std::vector global_sizes = AllgatherV(sizes); - std::vector offset(global_sizes.size() + 1); - offset[0] = 0; - for (std::size_t i = 1; i < offset.size(); i++) { - offset[i] = offset[i - 1] + global_sizes[i - 1]; - } - - std::vector collected; - for (auto const &vec : input) { - collected.insert(collected.end(), vec.cbegin(), vec.cend()); - } - auto out = AllgatherV(collected); - - std::vector> result; - for (std::size_t i = 1; i < offset.size(); ++i) { - std::vector local(out.cbegin() + offset[i - 1], out.cbegin() + offset[i]); - result.emplace_back(std::move(local)); - } - return result; -} -} // namespace xgboost::collective diff --git a/src/collective/communicator-inl.cuh b/src/collective/communicator-inl.cuh deleted file mode 100644 index 200a9ff4aa01..000000000000 --- a/src/collective/communicator-inl.cuh +++ /dev/null @@ -1,95 +0,0 @@ -/** - * Copyright 2023 by XGBoost contributors - */ -#pragma once -#include -#include - -#include "communicator.h" -#include "device_communicator.cuh" - -namespace xgboost { -namespace collective { - -/** - * @brief Reduce values from all processes and distribute the result back to all processes. - * @param device ID of the device. - * @param send_receive_buffer Buffer storing the data. - * @param count Number of elements in the buffer. - */ -template -inline void AllReduce(int device, std::int8_t *send_receive_buffer, size_t count) { - Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt8, op); -} - -template -inline void AllReduce(int device, std::uint8_t *send_receive_buffer, size_t count) { - Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt8, op); -} - -template -inline void AllReduce(int device, std::int32_t *send_receive_buffer, size_t count) { - Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt32, op); -} - -template -inline void AllReduce(int device, std::uint32_t *send_receive_buffer, size_t count) { - Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt32, op); -} - -template -inline void AllReduce(int device, std::int64_t *send_receive_buffer, size_t count) { - Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt64, op); -} - -template -inline void AllReduce(int device, std::uint64_t *send_receive_buffer, size_t count) { - Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt64, op); -} - -template -inline void AllReduce(int device, float *send_receive_buffer, size_t count) { - Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kFloat, op); -} - -template -inline void AllReduce(int device, double *send_receive_buffer, size_t count) { - Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kDouble, op); -} - -/** - * @brief Gather values from all all processes. - * - * This assumes all ranks have the same size. - * - * @param send_buffer Buffer storing the data to be sent. - * @param receive_buffer Buffer storing the gathered data. - * @param send_size Size of the sent data in bytes. - */ -inline void AllGather(int device, void const *send_buffer, void *receive_buffer, - std::size_t send_size) { - Communicator::GetDevice(device)->AllGather(send_buffer, receive_buffer, send_size); -} - -/** - * @brief Gather variable-length values from all processes. - * @param device ID of the device. - * @param send_buffer Buffer storing the input data. - * @param length_bytes Length in bytes of the input data. - * @param segments Size of each segment. - * @param receive_buffer Buffer storing the output data. - */ -inline void AllGatherV(int device, void const *send_buffer, size_t length_bytes, - std::vector *segments, - dh::caching_device_vector *receive_buffer) { - Communicator::GetDevice(device)->AllGatherV(send_buffer, length_bytes, segments, receive_buffer); -} - -/** - * @brief Synchronize device operations. - * @param device ID of the device. - */ -inline void Synchronize(int device) { Communicator::GetDevice(device)->Synchronize(); } - -} // namespace collective -} // namespace xgboost diff --git a/src/collective/communicator-inl.h b/src/collective/communicator-inl.h index 991e19f2c65a..2632007009ed 100644 --- a/src/collective/communicator-inl.h +++ b/src/collective/communicator-inl.h @@ -3,308 +3,63 @@ */ #pragma once #include -#include -#include "communicator.h" - -namespace xgboost { -namespace collective { +#include "xgboost/json.h" // for Json +namespace xgboost::collective { /** - * \brief Initialize the collective communicator. - * - * Currently the communicator API is experimental, function signatures may change in the future - * without notice. - * - * Call this once before using anything. - * - * The additional configuration is not required. Usually the communicator will detect settings - * from environment variables. - * - * \param json_config JSON encoded configuration. Accepted JSON keys are: - * - xgboost_communicator: The type of the communicator. Can be set as an environment variable. - * * rabit: Use Rabit. This is the default if the type is unspecified. - * * mpi: Use MPI. - * * federated: Use the gRPC interface for Federated Learning. - * Only applicable to the Rabit communicator (these are case-sensitive): - * - rabit_tracker_uri: Hostname of the tracker. - * - rabit_tracker_port: Port number of the tracker. - * - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment. - * - rabit_world_size: Total number of workers. - * - rabit_hadoop_mode: Enable Hadoop support. - * - rabit_tree_reduce_minsize: Minimal size for tree reduce. - * - rabit_reduce_ring_mincount: Minimal count to perform ring reduce. - * - rabit_reduce_buffer: Size of the reduce buffer. - * - rabit_bootstrap_cache: Size of the bootstrap cache. - * - rabit_debug: Enable debugging. - * - rabit_timeout: Enable timeout. - * - rabit_timeout_sec: Timeout in seconds. - * - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms. - * Only applicable to the Rabit communicator (these are case-sensitive, and can be set as - * environment variables): - * - DMLC_TRACKER_URI: Hostname of the tracker. - * - DMLC_TRACKER_PORT: Port number of the tracker. - * - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment. - * - DMLC_ROLE: Role of the current task, "worker" or "server". - * - DMLC_NUM_ATTEMPT: Number of attempts after task failure. - * - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker. - * Only applicable to the Federated communicator (use upper case for environment variables, use - * lower case for runtime configuration): - * - federated_server_address: Address of the federated server. - * - federated_world_size: Number of federated workers. - * - federated_rank: Rank of the current worker. - * - federated_server_cert: Server certificate file path. Only needed for the SSL mode. - * - federated_client_key: Client key file path. Only needed for the SSL mode. - * - federated_client_cert: Client certificate file path. Only needed for the SSL mode. + * @brief Initialize the collective communicator. */ -inline void Init(Json const &config) { Communicator::Init(config); } +void Init(Json const& config); -/*! - * \brief Finalize the collective communicator. +/** + * @brief Finalize the collective communicator. * * Call this function after you finished all jobs. */ -inline void Finalize() { Communicator::Finalize(); } - -/*! - * \brief Get rank of current process. - * - * \return Rank of the worker. - */ -inline int GetRank() { return Communicator::Get()->GetRank(); } - -/*! - * \brief Get total number of processes. - * - * \return Total world size. - */ -inline int GetWorldSize() { return Communicator::Get()->GetWorldSize(); } - -/*! - * \brief Get if the communicator is distributed. - * - * \return True if the communicator is distributed. - */ -inline bool IsDistributed() { return Communicator::Get()->IsDistributed(); } - -/*! - * \brief Get if the communicator is federated. - * - * \return True if the communicator is federated. - */ -inline bool IsFederated() { return Communicator::Get()->IsFederated(); } - -/*! - * \brief Print the message to the communicator. - * - * This function can be used to communicate the information of the progress to the user who monitors - * the communicator. - * - * \param message The message to be printed. - */ -inline void Print(char const *message) { Communicator::Get()->Print(message); } - -inline void Print(std::string const &message) { Communicator::Get()->Print(message); } - -/*! - * \brief Get the name of the processor. - * - * \return Name of the processor. - */ -inline std::string GetProcessorName() { return Communicator::Get()->GetProcessorName(); } +void Finalize(); -/*! - * \brief Broadcast a memory region to all others from root. This function is NOT thread-safe. - * - * Example: - * int a = 1; - * Broadcast(&a, sizeof(a), root); +/** + * @brief Get rank of current process. * - * \param send_receive_buffer Pointer to the send or receive buffer. - * \param size Size of the data. - * \param root The process rank to broadcast from. + * @return Rank of the worker. */ -inline void Broadcast(void *send_receive_buffer, size_t size, int root) { - Communicator::Get()->Broadcast(send_receive_buffer, size, root); -} - -inline void Broadcast(std::string *sendrecv_data, int root) { - size_t size = sendrecv_data->length(); - Broadcast(&size, sizeof(size), root); - if (sendrecv_data->length() != size) { - sendrecv_data->resize(size); - } - if (size != 0) { - Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root); - } -} +[[nodiscard]] std::int32_t GetRank() noexcept; /** - * @brief Gathers a single value all processes and distributes the result to all processes. + * @brief Get total number of processes. * - * @param input The single value. + * @return Total world size. */ -template -inline std::vector Allgather(T const &input) { - std::string_view str_input{reinterpret_cast(&input), sizeof(T)}; - auto const output = Communicator::Get()->AllGather(str_input); - CHECK_EQ(output.size() % sizeof(T), 0); - std::vector result(output.size() / sizeof(T)); - std::memcpy(reinterpret_cast(result.data()), output.data(), output.size()); - return result; -} +[[nodiscard]] std::int32_t GetWorldSize() noexcept; /** - * @brief Gathers data from all processes and distributes it to all processes. - * - * This assumes all ranks have the same size. + * @brief Get if the communicator is distributed. * - * @param input Buffer storing the data. + * @return True if the communicator is distributed. */ -template -inline std::vector Allgather(std::vector const &input) { - if (input.empty()) { - return input; - } - std::string_view str_input{reinterpret_cast(input.data()), - input.size() * sizeof(T)}; - auto const output = Communicator::Get()->AllGather(str_input); - CHECK_EQ(output.size() % sizeof(T), 0); - std::vector result(output.size() / sizeof(T)); - std::memcpy(reinterpret_cast(result.data()), output.data(), output.size()); - return result; -} +[[nodiscard]] bool IsDistributed() noexcept; /** - * @brief Gathers variable-length data from all processes and distributes it to all processes. - * @param input Buffer storing the data. + * @brief Get if the communicator is federated. + * + * @return True if the communicator is federated. */ -template -inline std::vector AllgatherV(std::vector const &input) { - std::string_view str_input{reinterpret_cast(input.data()), - input.size() * sizeof(T)}; - auto const output = Communicator::Get()->AllGatherV(str_input); - CHECK_EQ(output.size() % sizeof(T), 0); - std::vector result(output.size() / sizeof(T)); - if (!output.empty()) { - std::memcpy(reinterpret_cast(result.data()), output.data(), output.size()); - } - return result; -} +[[nodiscard]] bool IsFederated(); /** - * @brief Gathers variable-length data from all processes and distributes it to all processes. + * @brief Print the message to the communicator. * - * @param inputs All the inputs from the local worker. The number of inputs can vary - * across different workers. Along with which, the size of each vector in - * the input can also vary. + * This function can be used to communicate the information of the progress to the user who monitors + * the communicator. * - * @return The AllgatherV result, containing vectors from all workers. + * @param message The message to be printed. */ -[[nodiscard]] std::vector> VectorAllgatherV( - std::vector> const &input); - +void Print(std::string const& message); /** - * @brief Gathers variable-length strings from all processes and distributes them to all processes. - * @param input Variable-length list of variable-length strings. - */ -inline std::vector AllgatherStrings(std::vector const &input) { - std::size_t total_size{0}; - for (auto const &s : input) { - total_size += s.length() + 1; // +1 for null-terminators - } - std::string flat_string; - flat_string.reserve(total_size); - for (auto const &s : input) { - flat_string.append(s); - flat_string.push_back('\0'); // Append a null-terminator after each string - } - - auto const output = Communicator::Get()->AllGatherV(flat_string); - - std::vector result; - std::size_t start_index = 0; - // Iterate through the output, find each null-terminated substring. - for (std::size_t i = 0; i < output.size(); i++) { - if (output[i] == '\0') { - // Construct a std::string from the char* substring - result.emplace_back(&output[start_index]); - // Move to the next substring - start_index = i + 1; - } - } - return result; -} - -/*! - * \brief Perform in-place allreduce. This function is NOT thread-safe. + * @brief Get the name of the processor. * - * Example Usage: the following code gives sum of the result - * vector data(10); - * ... - * Allreduce(&data[0], data.size(), DataType:kInt32, Op::kSum); - * ... - * \param send_receive_buffer Buffer for both sending and receiving data. - * \param count Number of elements to be reduced. - * \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h. - * \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h. + * @return Name of the processor. */ -inline void Allreduce(void *send_receive_buffer, size_t count, int data_type, int op) { - Communicator::Get()->AllReduce(send_receive_buffer, count, static_cast(data_type), - static_cast(op)); -} - -inline void Allreduce(void *send_receive_buffer, size_t count, DataType data_type, Operation op) { - Communicator::Get()->AllReduce(send_receive_buffer, count, data_type, op); -} - -template -inline void Allreduce(int8_t *send_receive_buffer, size_t count) { - Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt8, op); -} - -template -inline void Allreduce(uint8_t *send_receive_buffer, size_t count) { - Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt8, op); -} - -template -inline void Allreduce(int32_t *send_receive_buffer, size_t count) { - Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt32, op); -} - -template -inline void Allreduce(uint32_t *send_receive_buffer, size_t count) { - Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt32, op); -} - -template -inline void Allreduce(int64_t *send_receive_buffer, size_t count) { - Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt64, op); -} - -template -inline void Allreduce(uint64_t *send_receive_buffer, size_t count) { - Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op); -} - -// Specialization for size_t, which is implementation defined, so it might or might not -// be one of uint64_t/uint32_t/unsigned long long/unsigned long. -template {} && !std::is_same{}> > -inline void Allreduce(T *send_receive_buffer, size_t count) { - static_assert(sizeof(T) == sizeof(uint64_t)); - Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op); -} - -template -inline void Allreduce(float *send_receive_buffer, size_t count) { - Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kFloat, op); -} - -template -inline void Allreduce(double *send_receive_buffer, size_t count) { - Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kDouble, op); -} -} // namespace collective -} // namespace xgboost +std::string GetProcessorName(); +} // namespace xgboost::collective diff --git a/src/collective/communicator.cc b/src/collective/communicator.cc deleted file mode 100644 index 7fabe50b465d..000000000000 --- a/src/collective/communicator.cc +++ /dev/null @@ -1,63 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#include "communicator.h" - -#include "comm.h" -#include "in_memory_communicator.h" -#include "noop_communicator.h" -#include "rabit_communicator.h" - -#if defined(XGBOOST_USE_FEDERATED) -#include "../../plugin/federated/federated_communicator.h" -#endif - -namespace xgboost::collective { -thread_local std::unique_ptr Communicator::communicator_{new NoOpCommunicator()}; -thread_local CommunicatorType Communicator::type_{}; -thread_local std::string Communicator::nccl_path_{}; - -void Communicator::Init(Json const& config) { - auto nccl = OptionalArg(config, "dmlc_nccl_path", std::string{DefaultNcclName()}); - nccl_path_ = nccl; - - auto type = GetTypeFromEnv(); - auto const arg = GetTypeFromConfig(config); - if (arg != CommunicatorType::kUnknown) { - type = arg; - } - if (type == CommunicatorType::kUnknown) { - // Default to Rabit if unspecified. - type = CommunicatorType::kRabit; - } - type_ = type; - switch (type) { - case CommunicatorType::kRabit: { - communicator_.reset(RabitCommunicator::Create(config)); - break; - } - case CommunicatorType::kFederated: { -#if defined(XGBOOST_USE_FEDERATED) - communicator_.reset(FederatedCommunicator::Create(config)); -#else - LOG(FATAL) << "XGBoost is not compiled with Federated Learning support."; -#endif - break; - } - case CommunicatorType::kInMemory: - case CommunicatorType::kInMemoryNccl: { - communicator_.reset(InMemoryCommunicator::Create(config)); - break; - } - case CommunicatorType::kUnknown: - LOG(FATAL) << "Unknown communicator type."; - } -} - -#ifndef XGBOOST_USE_CUDA -void Communicator::Finalize() { - communicator_->Shutdown(); - communicator_.reset(new NoOpCommunicator()); -} -#endif -} // namespace xgboost::collective diff --git a/src/collective/communicator.cu b/src/collective/communicator.cu deleted file mode 100644 index a7552d35690e..000000000000 --- a/src/collective/communicator.cu +++ /dev/null @@ -1,54 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#include "communicator.h" -#include "device_communicator.cuh" -#include "device_communicator_adapter.cuh" -#include "noop_communicator.h" -#ifdef XGBOOST_USE_NCCL -#include "nccl_device_communicator.cuh" -#endif - -namespace xgboost { -namespace collective { - -thread_local std::unique_ptr Communicator::device_communicator_{}; - -void Communicator::Finalize() { - communicator_->Shutdown(); - communicator_.reset(new NoOpCommunicator()); - device_communicator_.reset(nullptr); -} - -DeviceCommunicator* Communicator::GetDevice(int device_ordinal) { - thread_local auto old_device_ordinal = -1; - // If the number of GPUs changes, we need to re-initialize NCCL. - thread_local auto old_world_size = -1; - if (!device_communicator_ || device_ordinal != old_device_ordinal || - communicator_->GetWorldSize() != old_world_size) { - old_device_ordinal = device_ordinal; - old_world_size = communicator_->GetWorldSize(); -#ifdef XGBOOST_USE_NCCL - switch (type_) { - case CommunicatorType::kRabit: - device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false, nccl_path_)); - break; - case CommunicatorType::kFederated: - case CommunicatorType::kInMemory: - device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal)); - break; - case CommunicatorType::kInMemoryNccl: - device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true, nccl_path_)); - break; - default: - device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false, nccl_path_)); - } -#else - device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal)); -#endif - } - return device_communicator_.get(); -} - -} // namespace collective -} // namespace xgboost diff --git a/src/collective/communicator.h b/src/collective/communicator.h deleted file mode 100644 index b6910b80f1fd..000000000000 --- a/src/collective/communicator.h +++ /dev/null @@ -1,247 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#pragma once -#include -#include - -#include -#include - -namespace xgboost { -namespace collective { - -/** @brief Defines the integral and floating data types. */ -enum class DataType { - kInt8 = 0, - kUInt8 = 1, - kInt32 = 2, - kUInt32 = 3, - kInt64 = 4, - kUInt64 = 5, - kFloat = 6, - kDouble = 7 -}; - -/** @brief Get the size of the data type. */ -inline std::size_t GetTypeSize(DataType data_type) { - std::size_t size{0}; - switch (data_type) { - case DataType::kInt8: - size = sizeof(std::int8_t); - break; - case DataType::kUInt8: - size = sizeof(std::uint8_t); - break; - case DataType::kInt32: - size = sizeof(std::int32_t); - break; - case DataType::kUInt32: - size = sizeof(std::uint32_t); - break; - case DataType::kInt64: - size = sizeof(std::int64_t); - break; - case DataType::kUInt64: - size = sizeof(std::uint64_t); - break; - case DataType::kFloat: - size = sizeof(float); - break; - case DataType::kDouble: - size = sizeof(double); - break; - default: - LOG(FATAL) << "Unknown data type."; - } - return size; -} - -/** @brief Defines the reduction operation. */ -enum class Operation { - kMax = 0, - kMin = 1, - kSum = 2, - kBitwiseAND = 3, - kBitwiseOR = 4, - kBitwiseXOR = 5 -}; - -class DeviceCommunicator; - -enum class CommunicatorType { kUnknown, kRabit, kFederated, kInMemory, kInMemoryNccl }; - -/** \brief Case-insensitive string comparison. */ -inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) { -#ifdef _MSC_VER - return _stricmp(s1, s2); -#else // _MSC_VER - return strcasecmp(s1, s2); -#endif // _MSC_VER -} - -/** - * @brief A communicator class that handles collective communication. - */ -class Communicator { - public: - /** - * @brief Initialize the communicator. This can only be done once. - * - * @param config JSON configuration for the communicator. - */ - static void Init(Json const &config); - - /** @brief Finalize the communicator. */ - static void Finalize(); - - /** @brief Get the communicator instance. */ - static Communicator *Get() { return communicator_.get(); } - -#if defined(XGBOOST_USE_CUDA) - /** - * @brief Get the device communicator. - * - * @param device_ordinal ID of the device. - * @return An instance of device communicator. - */ - static DeviceCommunicator *GetDevice(int device_ordinal); -#endif - - virtual ~Communicator() = default; - - /** @brief Get the total number of processes. */ - int GetWorldSize() const { return world_size_; } - - /** @brief Get the rank of the current processes. */ - int GetRank() const { return rank_; } - - /** @brief Whether the communicator is running in distributed mode. */ - virtual bool IsDistributed() const = 0; - - /** @brief Whether the communicator is running in federated mode. */ - virtual bool IsFederated() const = 0; - - /** - * @brief Gathers data from all processes and distributes it to all processes. - * - * This assumes all ranks have the same size. - * - * @param input Buffer storing the data. - */ - virtual std::string AllGather(std::string_view input) = 0; - - /** - * @brief Gathers variable-length data from all processes and distributes it to all processes. - * @param input Buffer storing the data. - */ - virtual std::string AllGatherV(std::string_view input) = 0; - - /** - * @brief Combines values from all processes and distributes the result back to all processes. - * - * @param send_receive_buffer Buffer storing the data. - * @param count Number of elements in the buffer. - * @param data_type Data type stored in the buffer. - * @param op The operation to perform. - */ - virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, - Operation op) = 0; - - /** - * @brief Broadcasts a message from the process with rank `root` to all other processes of the - * group. - * - * @param send_receive_buffer Buffer storing the data. - * @param size Size of the data in bytes. - * @param root Rank of broadcast root. - */ - virtual void Broadcast(void *send_receive_buffer, std::size_t size, int root) = 0; - - /** - * @brief Gets the name of the processor. - */ - virtual std::string GetProcessorName() = 0; - - /** - * @brief Prints the message. - */ - virtual void Print(std::string const &message) = 0; - - /** @brief Get the communicator type from environment variables. Visible for testing. */ - static CommunicatorType GetTypeFromEnv() { - auto *env = std::getenv("XGBOOST_COMMUNICATOR"); - if (env != nullptr) { - return StringToType(env); - } else { - return CommunicatorType::kUnknown; - } - } - - /** @brief Get the communicator type from runtime configuration. Visible for testing. */ - static CommunicatorType GetTypeFromConfig(Json const &config) { - auto const &j_upper = config["XGBOOST_COMMUNICATOR"]; - if (IsA(j_upper)) { - return StringToType(get(j_upper).c_str()); - } - auto const &j_lower = config["xgboost_communicator"]; - if (IsA(j_lower)) { - return StringToType(get(j_lower).c_str()); - } - return CommunicatorType::kUnknown; - } - - protected: - /** - * @brief Construct a new communicator. - * - * @param world_size Total number of processes. - * @param rank Rank of the current process. - */ - Communicator(int world_size, int rank) : world_size_(world_size), rank_(rank) { - if (world_size < 1) { - LOG(FATAL) << "World size " << world_size << " is less than 1."; - } - if (rank < 0) { - LOG(FATAL) << "Rank " << rank << " is less than 0."; - } - if (rank >= world_size) { - LOG(FATAL) << "Rank " << rank << " is greater than world_size - 1: " << world_size - 1 << "."; - } - } - - /** - * @brief Shuts down the communicator. - */ - virtual void Shutdown() = 0; - - private: - static CommunicatorType StringToType(char const *str) { - CommunicatorType result = CommunicatorType::kUnknown; - if (!CompareStringsCaseInsensitive("rabit", str)) { - result = CommunicatorType::kRabit; - } else if (!CompareStringsCaseInsensitive("federated", str)) { - result = CommunicatorType::kFederated; - } else if (!CompareStringsCaseInsensitive("in-memory", str)) { - result = CommunicatorType::kInMemory; - } else if (!CompareStringsCaseInsensitive("in-memory-nccl", str)) { - result = CommunicatorType::kInMemoryNccl; - } else { - LOG(FATAL) << "Unknown communicator type " << str; - } - return result; - } - - static thread_local std::unique_ptr communicator_; - static thread_local CommunicatorType type_; - static thread_local std::string nccl_path_; -#if defined(XGBOOST_USE_CUDA) - static thread_local std::unique_ptr device_communicator_; -#endif - - int const world_size_; - int const rank_; -}; - -} // namespace collective -} // namespace xgboost diff --git a/src/collective/device_communicator.cuh b/src/collective/device_communicator.cuh deleted file mode 100644 index 69094b382591..000000000000 --- a/src/collective/device_communicator.cuh +++ /dev/null @@ -1,57 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#pragma once -#include - -#include "../common/device_helpers.cuh" - -namespace xgboost { -namespace collective { - -/** - * @brief Collective communicator for device buffers. - */ -class DeviceCommunicator { - public: - virtual ~DeviceCommunicator() = default; - - /** - * @brief Combines values from all processes and distributes the result back to all processes. - * - * @param send_receive_buffer Buffer storing the data. - * @param count Number of elements in the buffer. - * @param data_type Data type stored in the buffer. - * @param op The operation to perform. - */ - virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, - Operation op) = 0; - - /** - * @brief Gather values from all all processes. - * - * This assumes all ranks have the same size. - * - * @param send_buffer Buffer storing the data to be sent. - * @param receive_buffer Buffer storing the gathered data. - * @param send_size Size of the sent data in bytes. - */ - virtual void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) = 0; - - /** - * @brief Gather variable-length values from all processes. - * @param send_buffer Buffer storing the input data. - * @param length_bytes Length in bytes of the input data. - * @param segments Size of each segment. - * @param receive_buffer Buffer storing the output data. - */ - virtual void AllGatherV(void const *send_buffer, size_t length_bytes, - std::vector *segments, - dh::caching_device_vector *receive_buffer) = 0; - - /** @brief Synchronize device operations. */ - virtual void Synchronize() = 0; -}; - -} // namespace collective -} // namespace xgboost diff --git a/src/collective/device_communicator_adapter.cuh b/src/collective/device_communicator_adapter.cuh deleted file mode 100644 index 7d3e836a0ec9..000000000000 --- a/src/collective/device_communicator_adapter.cuh +++ /dev/null @@ -1,92 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#pragma once - -#include "communicator.h" -#include "device_communicator.cuh" - -namespace xgboost { -namespace collective { - -class DeviceCommunicatorAdapter : public DeviceCommunicator { - public: - explicit DeviceCommunicatorAdapter(int device_ordinal) - : device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} { - if (device_ordinal_ < 0) { - LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_; - } - } - - ~DeviceCommunicatorAdapter() override = default; - - void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, - Operation op) override { - if (world_size_ == 1) { - return; - } - - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - auto size = count * GetTypeSize(data_type); - host_buffer_.resize(size); - dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault)); - Allreduce(host_buffer_.data(), count, data_type, op); - dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault)); - } - - void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override { - if (world_size_ == 1) { - return; - } - - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - host_buffer_.resize(send_size); - dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_buffer, send_size, cudaMemcpyDefault)); - auto const output = Allgather(host_buffer_); - dh::safe_cuda(cudaMemcpy(receive_buffer, output.data(), output.size(), cudaMemcpyDefault)); - } - - void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector *segments, - dh::caching_device_vector *receive_buffer) override { - if (world_size_ == 1) { - return; - } - - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - - segments->clear(); - segments->resize(world_size_, 0); - segments->at(rank_) = length_bytes; - Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax); - auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL); - receive_buffer->resize(total_bytes); - - host_buffer_.resize(total_bytes); - size_t offset = 0; - for (int32_t i = 0; i < world_size_; ++i) { - size_t as_bytes = segments->at(i); - if (i == rank_) { - dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank_), - cudaMemcpyDefault)); - } - Broadcast(host_buffer_.data() + offset, as_bytes, i); - offset += as_bytes; - } - dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes, - cudaMemcpyDefault)); - } - - void Synchronize() override { - // Noop. - } - - private: - int const device_ordinal_; - int const world_size_; - int const rank_; - /// Host buffer used to call communicator functions. - std::vector host_buffer_{}; -}; - -} // namespace collective -} // namespace xgboost diff --git a/src/collective/in_memory_communicator.cc b/src/collective/in_memory_communicator.cc deleted file mode 100644 index 535a15bc9e1a..000000000000 --- a/src/collective/in_memory_communicator.cc +++ /dev/null @@ -1,12 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#include "in_memory_communicator.h" - -namespace xgboost { -namespace collective { - -InMemoryHandler InMemoryCommunicator::handler_{}; - -} // namespace collective -} // namespace xgboost diff --git a/src/collective/in_memory_communicator.h b/src/collective/in_memory_communicator.h index c712d32a8006..bd89be6e32dc 100644 --- a/src/collective/in_memory_communicator.h +++ b/src/collective/in_memory_communicator.h @@ -15,14 +15,14 @@ namespace collective { /** * An in-memory communicator, useful for testing. */ -class InMemoryCommunicator : public Communicator { +class InMemoryCommunicator { public: /** * @brief Create a new communicator based on JSON configuration. * @param config JSON configuration. * @return Communicator as specified by the JSON configuration. */ - static Communicator* Create(Json const& config) { + static InMemoryCommunicator* Create(Json const& config) { int world_size{0}; int rank{-1}; @@ -51,7 +51,7 @@ class InMemoryCommunicator : public Communicator { return new InMemoryCommunicator(world_size, rank); } - InMemoryCommunicator(int world_size, int rank) : Communicator(world_size, rank) { + InMemoryCommunicator(int world_size, int rank) { handler_.Init(world_size, rank); } diff --git a/src/collective/in_memory_handler.cc b/src/collective/in_memory_handler.cc index 944e5077b068..468f09c53048 100644 --- a/src/collective/in_memory_handler.cc +++ b/src/collective/in_memory_handler.cc @@ -1,14 +1,13 @@ -/*! - * Copyright 2022 XGBoost contributors +/** + * Copyright 2022-2023, XGBoost contributors */ #include "in_memory_handler.h" #include #include +#include "comm.h" -namespace xgboost { -namespace collective { - +namespace xgboost::collective { /** * @brief Functor for allgather. */ @@ -16,7 +15,7 @@ class AllgatherFunctor { public: std::string const name{"Allgather"}; - AllgatherFunctor(std::size_t world_size, std::size_t rank) + AllgatherFunctor(std::int32_t world_size, std::int32_t rank) : world_size_{world_size}, rank_{rank} {} void operator()(char const* input, std::size_t bytes, std::string* buffer) const { @@ -30,8 +29,8 @@ class AllgatherFunctor { } private: - std::size_t world_size_; - std::size_t rank_; + std::int32_t world_size_; + std::int32_t rank_; }; /** @@ -41,13 +40,13 @@ class AllgatherVFunctor { public: std::string const name{"AllgatherV"}; - AllgatherVFunctor(std::size_t world_size, std::size_t rank, + AllgatherVFunctor(std::int32_t world_size, std::int32_t rank, std::map* data) : world_size_{world_size}, rank_{rank}, data_{data} {} void operator()(char const* input, std::size_t bytes, std::string* buffer) const { data_->emplace(rank_, std::string_view{input, bytes}); - if (data_->size() == world_size_) { + if (data_->size() == static_cast(world_size_)) { for (auto const& kv : *data_) { buffer->append(kv.second); } @@ -56,8 +55,8 @@ class AllgatherVFunctor { } private: - std::size_t world_size_; - std::size_t rank_; + std::int32_t world_size_; + std::int32_t rank_; std::map* data_; }; @@ -68,7 +67,7 @@ class AllreduceFunctor { public: std::string const name{"Allreduce"}; - AllreduceFunctor(DataType dataType, Operation operation) + AllreduceFunctor(ArrayInterfaceHandler::Type dataType, Op operation) : data_type_{dataType}, operation_{operation} {} void operator()(char const* input, std::size_t bytes, std::string* buffer) const { @@ -76,23 +75,23 @@ class AllreduceFunctor { // Copy the input if this is the first request. buffer->assign(input, bytes); } else { + auto n_bytes_type = DispatchDType(data_type_, [](auto t) { return sizeof(t); }); // Apply the reduce_operation to the input and the buffer. - Accumulate(input, bytes / GetTypeSize(data_type_), &buffer->front()); + Accumulate(input, bytes / n_bytes_type, &buffer->front()); } } private: template ::value>* = nullptr> - void AccumulateBitwise(T* buffer, T const* input, std::size_t size, - Operation reduce_operation) const { + void AccumulateBitwise(T* buffer, T const* input, std::size_t size, Op reduce_operation) const { switch (reduce_operation) { - case Operation::kBitwiseAND: + case Op::kBitwiseAND: std::transform(buffer, buffer + size, input, buffer, std::bit_and()); break; - case Operation::kBitwiseOR: + case Op::kBitwiseOR: std::transform(buffer, buffer + size, input, buffer, std::bit_or()); break; - case Operation::kBitwiseXOR: + case Op::kBitwiseXOR: std::transform(buffer, buffer + size, input, buffer, std::bit_xor()); break; default: @@ -101,27 +100,27 @@ class AllreduceFunctor { } template ::value>* = nullptr> - void AccumulateBitwise(T*, T const*, std::size_t, Operation) const { + void AccumulateBitwise(T*, T const*, std::size_t, Op) const { LOG(FATAL) << "Floating point types do not support bitwise operations."; } template - void Accumulate(T* buffer, T const* input, std::size_t size, Operation reduce_operation) const { + void Accumulate(T* buffer, T const* input, std::size_t size, Op reduce_operation) const { switch (reduce_operation) { - case Operation::kMax: + case Op::kMax: std::transform(buffer, buffer + size, input, buffer, [](T a, T b) { return std::max(a, b); }); break; - case Operation::kMin: + case Op::kMin: std::transform(buffer, buffer + size, input, buffer, [](T a, T b) { return std::min(a, b); }); break; - case Operation::kSum: + case Op::kSum: std::transform(buffer, buffer + size, input, buffer, std::plus()); break; - case Operation::kBitwiseAND: - case Operation::kBitwiseOR: - case Operation::kBitwiseXOR: + case Op::kBitwiseAND: + case Op::kBitwiseOR: + case Op::kBitwiseXOR: AccumulateBitwise(buffer, input, size, reduce_operation); break; default: @@ -130,36 +129,37 @@ class AllreduceFunctor { } void Accumulate(char const* input, std::size_t size, char* buffer) const { + using Type = ArrayInterfaceHandler::Type; switch (data_type_) { - case DataType::kInt8: + case Type::kI1: Accumulate(reinterpret_cast(buffer), reinterpret_cast(input), size, operation_); break; - case DataType::kUInt8: + case Type::kU1: Accumulate(reinterpret_cast(buffer), reinterpret_cast(input), size, operation_); break; - case DataType::kInt32: + case Type::kI4: Accumulate(reinterpret_cast(buffer), reinterpret_cast(input), size, operation_); break; - case DataType::kUInt32: + case Type::kU4: Accumulate(reinterpret_cast(buffer), reinterpret_cast(input), size, operation_); break; - case DataType::kInt64: + case Type::kI8: Accumulate(reinterpret_cast(buffer), reinterpret_cast(input), size, operation_); break; - case DataType::kUInt64: + case Type::kU8: Accumulate(reinterpret_cast(buffer), reinterpret_cast(input), size, operation_); break; - case DataType::kFloat: + case Type::kF4: Accumulate(reinterpret_cast(buffer), reinterpret_cast(input), size, operation_); break; - case DataType::kDouble: + case Type::kF8: Accumulate(reinterpret_cast(buffer), reinterpret_cast(input), size, operation_); break; @@ -169,8 +169,8 @@ class AllreduceFunctor { } private: - DataType data_type_; - Operation operation_; + ArrayInterfaceHandler::Type data_type_; + Op operation_; }; /** @@ -180,7 +180,7 @@ class BroadcastFunctor { public: std::string const name{"Broadcast"}; - BroadcastFunctor(std::size_t rank, std::size_t root) : rank_{rank}, root_{root} {} + BroadcastFunctor(std::int32_t rank, std::int32_t root) : rank_{rank}, root_{root} {} void operator()(char const* input, std::size_t bytes, std::string* buffer) const { if (rank_ == root_) { @@ -190,11 +190,11 @@ class BroadcastFunctor { } private: - std::size_t rank_; - std::size_t root_; + std::int32_t rank_; + std::int32_t root_; }; -void InMemoryHandler::Init(std::size_t world_size, std::size_t) { +void InMemoryHandler::Init(std::int32_t world_size, std::int32_t) { CHECK(world_size_ < world_size) << "In memory handler already initialized."; std::unique_lock lock(mutex_); @@ -204,7 +204,7 @@ void InMemoryHandler::Init(std::size_t world_size, std::size_t) { cv_.notify_all(); } -void InMemoryHandler::Shutdown(uint64_t sequence_number, std::size_t) { +void InMemoryHandler::Shutdown(uint64_t sequence_number, std::int32_t) { CHECK(world_size_ > 0) << "In memory handler already shutdown."; std::unique_lock lock(mutex_); @@ -220,29 +220,29 @@ void InMemoryHandler::Shutdown(uint64_t sequence_number, std::size_t) { } void InMemoryHandler::Allgather(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, std::size_t rank) { + std::size_t sequence_number, std::int32_t rank) { Handle(input, bytes, output, sequence_number, rank, AllgatherFunctor{world_size_, rank}); } void InMemoryHandler::AllgatherV(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, std::size_t rank) { + std::size_t sequence_number, std::int32_t rank) { Handle(input, bytes, output, sequence_number, rank, AllgatherVFunctor{world_size_, rank, &aux_}); } void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, std::size_t rank, DataType data_type, - Operation op) { + std::size_t sequence_number, std::int32_t rank, + ArrayInterfaceHandler::Type data_type, Op op) { Handle(input, bytes, output, sequence_number, rank, AllreduceFunctor{data_type, op}); } void InMemoryHandler::Broadcast(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, std::size_t rank, std::size_t root) { + std::size_t sequence_number, std::int32_t rank, std::int32_t root) { Handle(input, bytes, output, sequence_number, rank, BroadcastFunctor{rank, root}); } template void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, std::size_t rank, + std::size_t sequence_number, std::int32_t rank, HandlerFunctor const& functor) { // Pass through if there is only 1 client. if (world_size_ == 1) { @@ -287,5 +287,4 @@ void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string* cv_.notify_all(); } } -} // namespace collective -} // namespace xgboost +} // namespace xgboost::collective diff --git a/src/collective/in_memory_handler.h b/src/collective/in_memory_handler.h index f9ac520079fd..7c3465d08b8b 100644 --- a/src/collective/in_memory_handler.h +++ b/src/collective/in_memory_handler.h @@ -1,16 +1,15 @@ -/*! - * Copyright 2022 XGBoost contributors +/** + * Copyright 2022-2023, XGBoost contributors */ #pragma once #include #include #include -#include "communicator.h" - -namespace xgboost { -namespace collective { +#include "../data/array_interface.h" +#include "comm.h" +namespace xgboost::collective { /** * @brief Handles collective communication primitives in memory. * @@ -28,11 +27,11 @@ class InMemoryHandler { /** * @brief Construct a handler with the given world size. - * @param world_size Number of workers. + * @param world Number of workers. * * This is used when the handler only needs to be initialized once with a known world size. */ - explicit InMemoryHandler(std::size_t worldSize) : world_size_{worldSize} {} + explicit InMemoryHandler(std::int32_t world) : world_size_{world} {} /** * @brief Initialize the handler with the world size and rank. @@ -42,7 +41,7 @@ class InMemoryHandler { * This is used when multiple objects/threads are accessing the same handler and need to * initialize it collectively. */ - void Init(std::size_t world_size, std::size_t rank); + void Init(std::int32_t world_size, std::int32_t rank); /** * @brief Shut down the handler. @@ -52,7 +51,7 @@ class InMemoryHandler { * This is used when multiple objects/threads are accessing the same handler and need to * shut it down collectively. */ - void Shutdown(uint64_t sequence_number, std::size_t rank); + void Shutdown(uint64_t sequence_number, std::int32_t rank); /** * @brief Perform allgather. @@ -63,7 +62,7 @@ class InMemoryHandler { * @param rank Index of the worker. */ void Allgather(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, std::size_t rank); + std::size_t sequence_number, std::int32_t rank); /** * @brief Perform variable-length allgather. @@ -74,7 +73,7 @@ class InMemoryHandler { * @param rank Index of the worker. */ void AllgatherV(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, std::size_t rank); + std::size_t sequence_number, std::int32_t rank); /** * @brief Perform allreduce. @@ -87,7 +86,8 @@ class InMemoryHandler { * @param op The reduce operation. */ void Allreduce(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, std::size_t rank, DataType data_type, Operation op); + std::size_t sequence_number, std::int32_t rank, + ArrayInterfaceHandler::Type data_type, Op op); /** * @brief Perform broadcast. @@ -99,7 +99,7 @@ class InMemoryHandler { * @param root Index of the worker to broadcast from. */ void Broadcast(char const* input, std::size_t bytes, std::string* output, - std::size_t sequence_number, std::size_t rank, std::size_t root); + std::size_t sequence_number, std::int32_t rank, std::int32_t root); private: /** @@ -114,17 +114,15 @@ class InMemoryHandler { */ template void Handle(char const* input, std::size_t size, std::string* output, std::size_t sequence_number, - std::size_t rank, HandlerFunctor const& functor); + std::int32_t rank, HandlerFunctor const& functor); - std::size_t world_size_{}; /// Number of workers. - std::size_t received_{}; /// Number of calls received with the current sequence. - std::size_t sent_{}; /// Number of calls completed with the current sequence. + std::int32_t world_size_{}; /// Number of workers. + std::int64_t received_{}; /// Number of calls received with the current sequence. + std::int64_t sent_{}; /// Number of calls completed with the current sequence. std::string buffer_{}; /// A shared common buffer. std::map aux_{}; /// A shared auxiliary map. uint64_t sequence_number_{}; /// Call sequence number. mutable std::mutex mutex_; /// Lock. mutable std::condition_variable cv_; /// Conditional variable to wait on. }; - -} // namespace collective -} // namespace xgboost +} // namespace xgboost::collective diff --git a/src/collective/loop.cc b/src/collective/loop.cc index b51749fcdad5..1c384bb2814a 100644 --- a/src/collective/loop.cc +++ b/src/collective/loop.cc @@ -6,6 +6,8 @@ #include // for size_t #include // for int32_t #include // for exception, current_exception, rethrow_exception +#include // for promise +#include // for make_shared #include // for lock_guard, unique_lock #include // for queue #include // for string @@ -18,9 +20,12 @@ #include "xgboost/logging.h" // for CHECK namespace xgboost::collective { -Result Loop::EmptyQueue(std::queue* p_queue) const { +Result Loop::ProcessQueue(std::queue* p_queue) const { timer_.Start(__func__); - auto error = [this] { timer_.Stop(__func__); }; + auto error = [this](Op op) { + op.pr->set_value(); + timer_.Stop(__func__); + }; if (stop_) { timer_.Stop(__func__); @@ -36,7 +41,7 @@ Result Loop::EmptyQueue(std::queue* p_queue) const { // Iterate through all the ops for poll for (std::size_t i = 0; i < n_ops; ++i) { - auto op = qcopy.front(); + auto op = std::move(qcopy.front()); qcopy.pop(); switch (op.code) { @@ -48,39 +53,53 @@ Result Loop::EmptyQueue(std::queue* p_queue) const { poll.WatchWrite(*op.sock); break; } + case Op::kSleep: { + break; + } default: { - error(); + error(op); return Fail("Invalid socket operation."); } } - qcopy.push(op); + qcopy.push(std::move(op)); } // poll, work on fds that are ready. timer_.Start("poll"); - auto rc = poll.Poll(timeout_); - timer_.Stop("poll"); - if (!rc.OK()) { - error(); - return rc; + if (!poll.fds.empty()) { + auto rc = poll.Poll(timeout_); + if (!rc.OK()) { + timer_.Stop(__func__); + return rc; + } } + timer_.Stop("poll"); - // we wonldn't be here if the queue is empty. + // We wonldn't be here if the queue is empty. CHECK(!qcopy.empty()); // Iterate through all the ops for performing the operations for (std::size_t i = 0; i < n_ops; ++i) { - auto op = qcopy.front(); + auto op = std::move(qcopy.front()); qcopy.pop(); std::int32_t n_bytes_done{0}; - CHECK(op.sock->NonBlocking()); + if (!op.sock) { + CHECK(op.code == Op::kSleep); + } else { + CHECK(op.sock->NonBlocking()); + } switch (op.code) { case Op::kRead: { if (poll.CheckRead(*op.sock)) { n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off); + if (n_bytes_done == 0) { + error(op); + return Fail("Encountered EOF. The other end is likely closed.", + op.sock->GetSockError()); + } } break; } @@ -90,15 +109,21 @@ Result Loop::EmptyQueue(std::queue* p_queue) const { } break; } + case Op::kSleep: { + // For testing only. + std::this_thread::sleep_for(std::chrono::seconds{op.n}); + n_bytes_done = op.n; + break; + } default: { - error(); + error(op); return Fail("Invalid socket operation."); } } if (n_bytes_done == -1 && !system::LastErrorWouldBlock()) { auto rc = system::FailWithCode("Invalid socket output."); - error(); + error(op); return rc; } @@ -106,8 +131,10 @@ Result Loop::EmptyQueue(std::queue* p_queue) const { CHECK_LE(op.off, op.n); if (op.off != op.n) { - // not yet finished, push back to queue for next round. + // not yet finished, push back to queue for the next round. qcopy.push(op); + } else { + op.pr->set_value(); } } } @@ -123,11 +150,19 @@ void Loop::Process() { }; // This loop cannot exit unless `stop_` is set to true. There must always be a thread to - // answer the blocking call even if there are errors, otherwise the blocking will wait - // forever. + // answer the call even if there are errors. while (true) { try { std::unique_lock lock{mu_}; + // This can handle missed notification: wait(lock, predicate) is equivalent to: + // + // while (!predicate()) { + // cv.wait(lock); + // } + // + // As a result, if there's a missed notification, the queue wouldn't be empty, hence + // the predicate would be false and the actual wait wouldn't be invoked. Therefore, + // the blocking call can never go unanswered. cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; }); if (stop_) { break; // only point where this loop can exit. @@ -136,43 +171,15 @@ void Loop::Process() { // Move the global queue into a local variable to unblock it. std::queue qcopy; - bool is_blocking = false; while (!queue_.empty()) { - auto op = queue_.front(); + auto op = std::move(queue_.front()); queue_.pop(); - if (op.code == Op::kBlock) { - is_blocking = true; - // Block must be the last op in the current batch since no further submit can be - // issued until the blocking call is finished. - CHECK(queue_.empty()); - } else { - qcopy.push(op); - } - } - - if (!is_blocking) { - // Unblock, we can write to the global queue again. - lock.unlock(); - } - - // Clear the local queue, this is blocking the current worker thread (but not the - // client thread), wait until all operations are finished. - auto rc = this->EmptyQueue(&qcopy); - - if (is_blocking) { - // The unlock is delayed if this is a blocking call - lock.unlock(); + qcopy.push(op); } + lock.unlock(); - // Notify the client thread who called block after all error conditions are set. - auto notify_if_block = [&] { - if (is_blocking) { - std::unique_lock lock{mu_}; - block_done_ = true; - lock.unlock(); - block_cv_.notify_one(); - } - }; + // Clear the local queue. + auto rc = this->ProcessQueue(&qcopy); // Handle error if (!rc.OK()) { @@ -180,8 +187,6 @@ void Loop::Process() { } else { CHECK(qcopy.empty()); } - - notify_if_block(); } catch (std::exception const& e) { curr_exce_ = std::current_exception(); set_rc(Fail("Exception inside the event loop:" + std::string{e.what()})); @@ -221,21 +226,28 @@ Result Loop::Stop() { stop_ = true; } } - if (!this->worker_.joinable()) { std::lock_guard guard{rc_lock_}; return Fail("Worker has stopped.", std::move(rc_)); } - this->Submit(Op{Op::kBlock}); - { - // Wait for the block call to finish. std::unique_lock lock{mu_}; - block_cv_.wait(lock, [this] { return block_done_ || stop_; }); - block_done_ = false; + cv_.notify_one(); } + for (auto& fut : futures_) { + if (fut.valid()) { + try { + fut.get(); + } catch (std::future_error const&) { + // Do nothing. If something went wrong in the worker, we have a std::future_error + // due to broken promise. This function will transfer the rc back to the caller. + } + } + } + futures_.clear(); + { // Transfer the rc. std::lock_guard lock{rc_lock_}; @@ -243,8 +255,20 @@ Result Loop::Stop() { } } +void Loop::Submit(Op op) { + auto p = std::make_shared>(); + op.pr = std::move(p); + futures_.emplace_back(op.pr->get_future()); + CHECK_NE(op.n, 0); + + std::unique_lock lock{mu_}; + queue_.push(op); +} + Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} { timer_.Init(__func__); - worker_ = std::thread{[this] { this->Process(); }}; + worker_ = std::thread{[this] { + this->Process(); + }}; } } // namespace xgboost::collective diff --git a/src/collective/loop.h b/src/collective/loop.h index 4839abfd3917..0a830eb960b1 100644 --- a/src/collective/loop.h +++ b/src/collective/loop.h @@ -7,9 +7,12 @@ #include // for size_t #include // for int8_t, int32_t #include // for exception_ptr -#include // for unique_lock, mutex +#include // for future +#include // for shared_ptr +#include // for mutex #include // for queue #include // for thread +#include // for vector #include "../common/timer.h" // for Monitor #include "xgboost/collective/result.h" // for Result @@ -19,31 +22,38 @@ namespace xgboost::collective { class Loop { public: struct Op { - enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2 } code; + // kSleep is only for testing + enum Code : std::int8_t { kRead = 0, kWrite = 1, kSleep = 3 } code; std::int32_t rank{-1}; std::int8_t* ptr{nullptr}; std::size_t n{0}; TCPSocket* sock{nullptr}; std::size_t off{0}; + std::shared_ptr> pr; - explicit Op(Code c) : code{c} { CHECK(c == kBlock); } + explicit Op(Code c) : code{c} { CHECK(c == kSleep); } Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off) : code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {} Op(Op const&) = default; Op& operator=(Op const&) = default; Op(Op&&) = default; Op& operator=(Op&&) = default; + // For testing purpose only + [[nodiscard]] static Op Sleep(std::size_t seconds) { + Op op{kSleep}; + op.n = seconds; + return op; + } }; private: std::thread worker_; // thread worker to execute the tasks - std::condition_variable cv_; // CV used to notify a new submit call - std::condition_variable block_cv_; // CV used to notify the blocking call - bool block_done_{false}; // Flag to indicate whether the blocking call has finished. + std::condition_variable cv_; // CV used to notify a new submit call std::queue queue_; // event queue - std::mutex mu_; // mutex to protect the queue, cv, and block_done + std::vector> futures_; + std::mutex mu_; // mutex to protect the queue, cv, and block_done std::chrono::seconds timeout_; @@ -54,7 +64,7 @@ class Loop { std::exception_ptr curr_exce_{nullptr}; common::Monitor mutable timer_; - Result EmptyQueue(std::queue* p_queue) const; + Result ProcessQueue(std::queue* p_queue) const; // The cunsumer function that runs inside a worker thread. void Process(); @@ -64,12 +74,7 @@ class Loop { */ Result Stop(); - void Submit(Op op) { - std::unique_lock lock{mu_}; - queue_.push(op); - lock.unlock(); - cv_.notify_one(); - } + void Submit(Op op); /** * @brief Block the event loop until all ops are finished. In the case of failure, this diff --git a/src/collective/nccl_device_communicator.cu b/src/collective/nccl_device_communicator.cu deleted file mode 100644 index 31c2d394d7cc..000000000000 --- a/src/collective/nccl_device_communicator.cu +++ /dev/null @@ -1,241 +0,0 @@ -/*! - * Copyright 2023 XGBoost contributors - */ -#if defined(XGBOOST_USE_NCCL) -#include "comm.cuh" -#include "nccl_device_communicator.cuh" - -namespace xgboost { -namespace collective { - -NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync, - StringView nccl_path) - : device_ordinal_{device_ordinal}, - needs_sync_{needs_sync}, - world_size_{GetWorldSize()}, - rank_{GetRank()} { - if (device_ordinal_ < 0) { - LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_; - } - if (world_size_ == 1) { - return; - } - stub_ = std::make_shared(std::move(nccl_path)); - - std::vector uuids(world_size_ * kUuidLength, 0); - auto s_uuid = xgboost::common::Span{uuids.data(), uuids.size()}; - auto s_this_uuid = s_uuid.subspan(rank_ * kUuidLength, kUuidLength); - GetCudaUUID(s_this_uuid); - - // TODO(rongou): replace this with allgather. - Allreduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum); - - std::vector> converted(world_size_); - size_t j = 0; - for (size_t i = 0; i < uuids.size(); i += kUuidLength) { - converted[j] = xgboost::common::Span{uuids.data() + i, kUuidLength}; - j++; - } - - auto iter = std::unique(converted.begin(), converted.end()); - auto n_uniques = std::distance(converted.begin(), iter); - - CHECK_EQ(n_uniques, world_size_) - << "Multiple processes within communication group running on same CUDA " - << "device is not supported. " << PrintUUID(s_this_uuid) << "\n"; - - nccl_unique_id_ = GetUniqueId(); - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - auto rc = stub_->CommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_); - CHECK(rc.OK()) << rc.Report(); -} - -NcclDeviceCommunicator::~NcclDeviceCommunicator() { - if (world_size_ == 1) { - return; - } - if (nccl_comm_) { - auto rc = stub_->CommDestroy(nccl_comm_); - CHECK(rc.OK()) << rc.Report(); - } - if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) { - LOG(CONSOLE) << "======== NCCL Statistics========"; - LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_; - LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576; - } -} - -namespace { -ncclDataType_t GetNcclDataType(DataType const &data_type) { - ncclDataType_t result{ncclInt8}; - switch (data_type) { - case DataType::kInt8: - result = ncclInt8; - break; - case DataType::kUInt8: - result = ncclUint8; - break; - case DataType::kInt32: - result = ncclInt32; - break; - case DataType::kUInt32: - result = ncclUint32; - break; - case DataType::kInt64: - result = ncclInt64; - break; - case DataType::kUInt64: - result = ncclUint64; - break; - case DataType::kFloat: - result = ncclFloat; - break; - case DataType::kDouble: - result = ncclDouble; - break; - default: - LOG(FATAL) << "Unknown data type."; - } - return result; -} - -bool IsBitwiseOp(Operation const &op) { - return op == Operation::kBitwiseAND || op == Operation::kBitwiseOR || - op == Operation::kBitwiseXOR; -} - -ncclRedOp_t GetNcclRedOp(Operation const &op) { - ncclRedOp_t result{ncclMax}; - switch (op) { - case Operation::kMax: - result = ncclMax; - break; - case Operation::kMin: - result = ncclMin; - break; - case Operation::kSum: - result = ncclSum; - break; - default: - LOG(FATAL) << "Unsupported reduce operation."; - } - return result; -} - -template -void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size, - std::size_t size) { - dh::LaunchN(size, [=] __device__(std::size_t idx) { - auto result = device_buffer[idx]; - for (auto rank = 1; rank < world_size; rank++) { - result = func(result, device_buffer[rank * size + idx]); - } - out_buffer[idx] = result; - }); -} -} // anonymous namespace - -void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::size_t count, - DataType data_type, Operation op) { - auto const size = count * GetTypeSize(data_type); - dh::caching_device_vector buffer(size * world_size_); - auto *device_buffer = buffer.data().get(); - - // First gather data from all the workers. - auto rc = stub_->Allgather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type), - nccl_comm_, dh::DefaultStream()); - CHECK(rc.OK()) << rc.Report(); - if (needs_sync_) { - dh::DefaultStream().Sync(); - } - - // Then reduce locally. - auto *out_buffer = static_cast(send_receive_buffer); - switch (op) { - case Operation::kBitwiseAND: - RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and(), world_size_, size); - break; - case Operation::kBitwiseOR: - RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or(), world_size_, size); - break; - case Operation::kBitwiseXOR: - RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor(), world_size_, size); - break; - default: - LOG(FATAL) << "Not a bitwise reduce operation."; - } -} - -void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t count, - DataType data_type, Operation op) { - if (world_size_ == 1) { - return; - } - - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - if (IsBitwiseOp(op)) { - BitwiseAllReduce(send_receive_buffer, count, data_type, op); - } else { - auto rc = stub_->Allreduce(send_receive_buffer, send_receive_buffer, count, - GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_, - dh::DefaultStream()); - CHECK(rc.OK()) << rc.Report(); - } - allreduce_bytes_ += count * GetTypeSize(data_type); - allreduce_calls_ += 1; -} - -void NcclDeviceCommunicator::AllGather(void const *send_buffer, void *receive_buffer, - std::size_t send_size) { - if (world_size_ == 1) { - return; - } - - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - auto rc = stub_->Allgather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_, - dh::DefaultStream()); - CHECK(rc.OK()) << rc.Report(); -} - -void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes, - std::vector *segments, - dh::caching_device_vector *receive_buffer) { - if (world_size_ == 1) { - return; - } - - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - - segments->clear(); - segments->resize(world_size_, 0); - segments->at(rank_) = length_bytes; - Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax); - auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL); - receive_buffer->resize(total_bytes); - - size_t offset = 0; - auto rc = Success() << [&] { return stub_->GroupStart(); } << [&] { - for (int32_t i = 0; i < world_size_; ++i) { - size_t as_bytes = segments->at(i); - auto rc = stub_->Broadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes, - ncclChar, i, nccl_comm_, dh::DefaultStream()); - if (!rc.OK()) { - return rc; - } - offset += as_bytes; - } - return Success(); - } << [&] { return stub_->GroupEnd(); }; -} - -void NcclDeviceCommunicator::Synchronize() { - if (world_size_ == 1) { - return; - } - dh::safe_cuda(cudaSetDevice(device_ordinal_)); - dh::DefaultStream().Sync(); -} - -} // namespace collective -} // namespace xgboost -#endif diff --git a/src/collective/nccl_device_communicator.cuh b/src/collective/nccl_device_communicator.cuh deleted file mode 100644 index ef431b571678..000000000000 --- a/src/collective/nccl_device_communicator.cuh +++ /dev/null @@ -1,91 +0,0 @@ -/*! - * Copyright 2022-2023 XGBoost contributors - */ -#pragma once - -#include "../common/device_helpers.cuh" -#include "comm.cuh" -#include "communicator.h" -#include "device_communicator.cuh" -#include "nccl_stub.h" - -namespace xgboost { -namespace collective { - -class NcclDeviceCommunicator : public DeviceCommunicator { - public: - /** - * @brief Construct a new NCCL communicator. - * @param device_ordinal The GPU device id. - * @param needs_sync Whether extra CUDA stream synchronization is needed. - * - * In multi-GPU tests when multiple NCCL communicators are created in the same process, sometimes - * a deadlock happens because NCCL kernels are blocking. The extra CUDA stream synchronization - * makes sure that the NCCL kernels are caught up, thus avoiding the deadlock. - * - * The Rabit communicator runs with one process per GPU, so the additional synchronization is not - * needed. The in-memory communicator is used in tests with multiple threads, each thread - * representing a rank/worker, so the additional synchronization is needed to avoid deadlocks. - */ - explicit NcclDeviceCommunicator(int device_ordinal, bool needs_sync, StringView nccl_path); - ~NcclDeviceCommunicator() override; - void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, - Operation op) override; - void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override; - void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector *segments, - dh::caching_device_vector *receive_buffer) override; - void Synchronize() override; - - private: - static constexpr std::size_t kUuidLength = - sizeof(std::declval().uuid) / sizeof(uint64_t); - - void GetCudaUUID(xgboost::common::Span const &uuid) const { - cudaDeviceProp prob{}; - dh::safe_cuda(cudaGetDeviceProperties(&prob, device_ordinal_)); - std::memcpy(uuid.data(), static_cast(&(prob.uuid)), sizeof(prob.uuid)); - } - - static std::string PrintUUID(xgboost::common::Span const &uuid) { - std::stringstream ss; - for (auto v : uuid) { - ss << std::hex << v; - } - return ss.str(); - } - - /** - * \fn ncclUniqueId GetUniqueId() - * - * \brief Gets the Unique ID from NCCL to be used in setting up interprocess - * communication - * - * \return the Unique ID - */ - ncclUniqueId GetUniqueId() { - static const int kRootRank = 0; - ncclUniqueId id; - if (rank_ == kRootRank) { - auto rc = stub_->GetUniqueId(&id); - CHECK(rc.OK()) << rc.Report(); - } - Broadcast(static_cast(&id), sizeof(ncclUniqueId), static_cast(kRootRank)); - return id; - } - - void BitwiseAllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, - Operation op); - - int const device_ordinal_; - bool const needs_sync_; - int const world_size_; - int const rank_; - ncclComm_t nccl_comm_{}; - std::shared_ptr stub_; - ncclUniqueId nccl_unique_id_{}; - size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated. - size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls. -}; - -} // namespace collective -} // namespace xgboost diff --git a/src/collective/noop_communicator.h b/src/collective/noop_communicator.h deleted file mode 100644 index 2d88fd8024d2..000000000000 --- a/src/collective/noop_communicator.h +++ /dev/null @@ -1,32 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#pragma once -#include - -#include "communicator.h" - -namespace xgboost { -namespace collective { - -/** - * A no-op communicator, used for non-distributed training. - */ -class NoOpCommunicator : public Communicator { - public: - NoOpCommunicator() : Communicator(1, 0) {} - bool IsDistributed() const override { return false; } - bool IsFederated() const override { return false; } - std::string AllGather(std::string_view) override { return {}; } - std::string AllGatherV(std::string_view) override { return {}; } - void AllReduce(void *, std::size_t, DataType, Operation) override {} - void Broadcast(void *, std::size_t, int) override {} - std::string GetProcessorName() override { return {}; } - void Print(const std::string &message) override { LOG(CONSOLE) << message; } - - protected: - void Shutdown() override {} -}; - -} // namespace collective -} // namespace xgboost diff --git a/src/collective/protocol.h b/src/collective/protocol.h index 96edf4e29bcf..2222594033f3 100644 --- a/src/collective/protocol.h +++ b/src/collective/protocol.h @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once #include // for int32_t @@ -41,23 +41,30 @@ struct Magic { [[nodiscard]] Result Verify(xgboost::collective::TCPSocket* p_sock) { std::int32_t magic{kMagic}; - auto n_bytes = p_sock->SendAll(&magic, sizeof(magic)); - if (n_bytes != sizeof(magic)) { - return Fail("Failed to verify."); - } - - magic = 0; - n_bytes = p_sock->RecvAll(&magic, sizeof(magic)); - if (n_bytes != sizeof(magic)) { - return Fail("Failed to verify."); - } - if (magic != kMagic) { - return xgboost::collective::Fail("Invalid verification number."); - } - return Success(); + std::size_t n_sent{0}; + return Success() << [&] { + return p_sock->SendAll(&magic, sizeof(magic), &n_sent); + } << [&] { + if (n_sent != sizeof(magic)) { + return Fail("Failed to verify."); + } + return Success(); + } << [&] { + magic = 0; + return p_sock->RecvAll(&magic, sizeof(magic), &n_sent); + } << [&] { + if (n_sent != sizeof(magic)) { + return Fail("Failed to verify."); + } + if (magic != kMagic) { + return xgboost::collective::Fail("Invalid verification number."); + } + return Success(); + }; } }; +// Basic commands for communication between workers and the tracker. enum class CMD : std::int32_t { kInvalid = 0, kStart = 1, @@ -84,7 +91,10 @@ struct Connect { [[nodiscard]] Result TrackerRecv(TCPSocket* sock, std::int32_t* world, std::int32_t* rank, std::string* task_id) const { std::string init; - sock->Recv(&init); + auto rc = sock->Recv(&init); + if (!rc.OK()) { + return Fail("Connect protocol failed.", std::move(rc)); + } auto jinit = Json::Load(StringView{init}); *world = get(jinit["world_size"]); *rank = get(jinit["rank"]); @@ -122,9 +132,9 @@ class Start { } [[nodiscard]] Result WorkerRecv(TCPSocket* tracker, std::int32_t* p_world) const { std::string scmd; - auto n_bytes = tracker->Recv(&scmd); - if (n_bytes <= 0) { - return Fail("Failed to recv init command from tracker."); + auto rc = tracker->Recv(&scmd); + if (!rc.OK()) { + return Fail("Failed to recv init command from tracker.", std::move(rc)); } auto jcmd = Json::Load(scmd); auto world = get(jcmd["world_size"]); @@ -132,7 +142,7 @@ class Start { return Fail("Invalid world size."); } *p_world = world; - return Success(); + return rc; } [[nodiscard]] Result TrackerHandle(Json jcmd, std::int32_t* recv_world, std::int32_t world, std::int32_t* p_port, TCPSocket* p_sock, @@ -150,6 +160,7 @@ class Start { } }; +// Protocol for communicating with the tracker for printing message. struct Print { [[nodiscard]] Result WorkerSend(TCPSocket* tracker, std::string msg) const { Json jcmd{Object{}}; @@ -172,6 +183,7 @@ struct Print { } }; +// Protocol for communicating with the tracker during error. struct ErrorCMD { [[nodiscard]] Result WorkerSend(TCPSocket* tracker, Result const& res) const { auto msg = res.Report(); @@ -199,6 +211,7 @@ struct ErrorCMD { } }; +// Protocol for communicating with the tracker during shutdown. struct ShutdownCMD { [[nodiscard]] Result Send(TCPSocket* peer) const { Json jcmd{Object{}}; @@ -211,4 +224,52 @@ struct ShutdownCMD { return Success(); } }; + +// Protocol for communicating with the local error handler during error or shutdown. Only +// one protocol that doesn't have the tracker involved. +struct Error { + constexpr static std::int32_t ShutdownSignal() { return 0; } + constexpr static std::int32_t ErrorSignal() { return -1; } + + [[nodiscard]] Result SignalError(TCPSocket* worker) const { + std::int32_t err{ErrorSignal()}; + std::size_t n_sent{0}; + return Success() << [&] { + return worker->SendAll(&err, sizeof(err), &n_sent); + } << [&] { + if (n_sent == sizeof(err)) { + return Success(); + } + return Fail("Failed to send error signal"); + }; + } + // self is localhost, we are sending the signal to the error handling thread for it to + // close. + [[nodiscard]] Result SignalShutdown(TCPSocket* self) const { + std::int32_t err{ShutdownSignal()}; + std::size_t n_sent{0}; + return Success() << [&] { + return self->SendAll(&err, sizeof(err), &n_sent); + } << [&] { + if (n_sent == sizeof(err)) { + return Success(); + } + return Fail("Failed to send shutdown signal"); + }; + } + // get signal, either for error or for shutdown. + [[nodiscard]] Result RecvSignal(TCPSocket* peer, bool* p_is_error) const { + std::int32_t err{ShutdownSignal()}; + std::size_t n_recv{0}; + return Success() << [&] { + return peer->RecvAll(&err, sizeof(err), &n_recv); + } << [&] { + if (n_recv == sizeof(err)) { + *p_is_error = err == 1; + return Success(); + } + return Fail("Failed to receive error signal."); + }; + } +}; } // namespace xgboost::collective::proto diff --git a/src/collective/rabit_communicator.h b/src/collective/rabit_communicator.h deleted file mode 100644 index 452e9ad9c73c..000000000000 --- a/src/collective/rabit_communicator.h +++ /dev/null @@ -1,175 +0,0 @@ -/** - * Copyright 2022-2023 by XGBoost contributors - */ -#pragma once -#include - -#include -#include - -#include "communicator-inl.h" -#include "communicator.h" -#include "xgboost/json.h" - -namespace xgboost { -namespace collective { - -class RabitCommunicator : public Communicator { - public: - static Communicator *Create(Json const &config) { - std::vector args_str; - for (auto &items : get(config)) { - switch (items.second.GetValue().Type()) { - case xgboost::Value::ValueKind::kString: { - args_str.push_back(items.first + "=" + get(items.second)); - break; - } - case xgboost::Value::ValueKind::kInteger: { - args_str.push_back(items.first + "=" + std::to_string(get(items.second))); - break; - } - case xgboost::Value::ValueKind::kBoolean: { - if (get(items.second)) { - args_str.push_back(items.first + "=1"); - } else { - args_str.push_back(items.first + "=0"); - } - break; - } - default: - break; - } - } - std::vector args; - for (auto &key_value : args_str) { - args.push_back(&key_value[0]); - } - if (!rabit::Init(static_cast(args.size()), &args[0])) { - LOG(FATAL) << "Failed to initialize Rabit"; - } - return new RabitCommunicator(rabit::GetWorldSize(), rabit::GetRank()); - } - - RabitCommunicator(int world_size, int rank) : Communicator(world_size, rank) {} - - bool IsDistributed() const override { return rabit::IsDistributed(); } - - bool IsFederated() const override { return false; } - - std::string AllGather(std::string_view input) override { - auto const per_rank = input.size(); - auto const total_size = per_rank * GetWorldSize(); - auto const index = per_rank * GetRank(); - std::string result(total_size, '\0'); - result.replace(index, per_rank, input); - rabit::Allgather(result.data(), total_size, index, per_rank, per_rank); - return result; - } - - std::string AllGatherV(std::string_view input) override { - auto const size_node_slice = input.size(); - auto const all_sizes = collective::Allgather(size_node_slice); - auto const total_size = std::accumulate(all_sizes.cbegin(), all_sizes.cend(), 0ul); - auto const begin_index = - std::accumulate(all_sizes.cbegin(), all_sizes.cbegin() + GetRank(), 0ul); - auto const size_prev_slice = - GetRank() == 0 ? all_sizes[GetWorldSize() - 1] : all_sizes[GetRank() - 1]; - - std::string result(total_size, '\0'); - result.replace(begin_index, size_node_slice, input); - rabit::Allgather(result.data(), total_size, begin_index, size_node_slice, size_prev_slice); - return result; - } - - void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, - Operation op) override { - switch (data_type) { - case DataType::kInt8: - DoAllReduce(send_receive_buffer, count, op); - break; - case DataType::kUInt8: - DoAllReduce(send_receive_buffer, count, op); - break; - case DataType::kInt32: - DoAllReduce(send_receive_buffer, count, op); - break; - case DataType::kUInt32: - DoAllReduce(send_receive_buffer, count, op); - break; - case DataType::kInt64: - DoAllReduce(send_receive_buffer, count, op); - break; - case DataType::kUInt64: - DoAllReduce(send_receive_buffer, count, op); - break; - case DataType::kFloat: - DoAllReduce(send_receive_buffer, count, op); - break; - case DataType::kDouble: - DoAllReduce(send_receive_buffer, count, op); - break; - default: - LOG(FATAL) << "Unknown data type"; - } - } - - void Broadcast(void *send_receive_buffer, std::size_t size, int root) override { - rabit::Broadcast(send_receive_buffer, size, root); - } - - std::string GetProcessorName() override { return rabit::GetProcessorName(); } - - void Print(const std::string &message) override { rabit::TrackerPrint(message); } - - protected: - void Shutdown() override { rabit::Finalize(); } - - private: - template ::value> * = nullptr> - void DoBitwiseAllReduce(void *send_receive_buffer, std::size_t count, Operation op) { - switch (op) { - case Operation::kBitwiseAND: - rabit::Allreduce(static_cast(send_receive_buffer), - count); - break; - case Operation::kBitwiseOR: - rabit::Allreduce(static_cast(send_receive_buffer), count); - break; - case Operation::kBitwiseXOR: - rabit::Allreduce(static_cast(send_receive_buffer), - count); - break; - default: - LOG(FATAL) << "Unknown allreduce operation"; - } - } - - template ::value> * = nullptr> - void DoBitwiseAllReduce(void *, std::size_t, Operation) { - LOG(FATAL) << "Floating point types do not support bitwise operations."; - } - - template - void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) { - switch (op) { - case Operation::kMax: - rabit::Allreduce(static_cast(send_receive_buffer), count); - break; - case Operation::kMin: - rabit::Allreduce(static_cast(send_receive_buffer), count); - break; - case Operation::kSum: - rabit::Allreduce(static_cast(send_receive_buffer), count); - break; - case Operation::kBitwiseAND: - case Operation::kBitwiseOR: - case Operation::kBitwiseXOR: - DoBitwiseAllReduce(send_receive_buffer, count, op); - break; - default: - LOG(FATAL) << "Unknown allreduce operation"; - } - } -}; -} // namespace collective -} // namespace xgboost diff --git a/src/collective/result.cc b/src/collective/result.cc new file mode 100644 index 000000000000..140efa6d8bee --- /dev/null +++ b/src/collective/result.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2024, XGBoost Contributors + */ +#include "xgboost/collective/result.h" + +#include // for path +#include // for stringstream +#include // for stack + +#include "xgboost/logging.h" + +namespace xgboost::collective { +namespace detail { +[[nodiscard]] std::string ResultImpl::Report() const { + std::stringstream ss; + ss << "\n- " << this->message; + if (this->errc != std::error_code{}) { + ss << " system error:" << this->errc.message(); + } + + auto ptr = prev.get(); + while (ptr) { + ss << "\n- "; + ss << ptr->message; + + if (ptr->errc != std::error_code{}) { + ss << " " << ptr->errc.message(); + } + ptr = ptr->prev.get(); + } + + return ss.str(); +} + +[[nodiscard]] std::error_code ResultImpl::Code() const { + // Find the root error. + std::stack stack; + auto ptr = this; + while (ptr) { + stack.push(ptr); + if (ptr->prev) { + ptr = ptr->prev.get(); + } else { + break; + } + } + while (!stack.empty()) { + auto frame = stack.top(); + stack.pop(); + if (frame->errc != std::error_code{}) { + return frame->errc; + } + } + return std::error_code{}; +} + +void ResultImpl::Concat(std::unique_ptr rhs) { + auto ptr = this; + while (ptr->prev) { + ptr = ptr->prev.get(); + } + ptr->prev = std::move(rhs); +} + +std::string MakeMsg(std::string&& msg, char const* file, std::int32_t line) { + dmlc::DateLogger logger; + if (file && line != -1) { + auto name = std::filesystem::path{ file }.filename(); + return "[" + name.string() + ":" + std::to_string(line) + "|" + logger.HumanDate() + + "]: " + std::forward(msg); + } + return std::string{"["} + logger.HumanDate() + "]" + std::forward(msg); // NOLINT +} +} // namespace detail + +void SafeColl(Result const& rc) { + if (!rc.OK()) { + LOG(FATAL) << rc.Report(); + } +} +} // namespace xgboost::collective diff --git a/src/collective/socket.cc b/src/collective/socket.cc index 43da366bd7de..99b02f665f10 100644 --- a/src/collective/socket.cc +++ b/src/collective/socket.cc @@ -1,5 +1,5 @@ /** - * Copyright 2022-2023 by XGBoost Contributors + * Copyright 2022-2024, XGBoost Contributors */ #include "xgboost/collective/socket.h" @@ -8,7 +8,8 @@ #include // std::int32_t #include // std::memcpy, std::memset #include // for path -#include // std::error_code, std::system_category +#include // for error_code, system_category +#include // for sleep_for #include "rabit/internal/socket.h" // for PollHelper #include "xgboost/collective/result.h" // for Result @@ -59,20 +60,46 @@ std::size_t TCPSocket::Send(StringView str) { CHECK(!this->IsClosed()); CHECK_LT(str.size(), std::numeric_limits::max()); std::int32_t len = static_cast(str.size()); - CHECK_EQ(this->SendAll(&len, sizeof(len)), sizeof(len)) << "Failed to send string length."; - auto bytes = this->SendAll(str.c_str(), str.size()); - CHECK_EQ(bytes, str.size()) << "Failed to send string."; - return bytes; + std::size_t n_bytes{0}; + auto rc = Success() << [&] { + return this->SendAll(&len, sizeof(len), &n_bytes); + } << [&] { + if (n_bytes != sizeof(len)) { + return Fail("Failed to send string length."); + } + return Success(); + } << [&] { + return this->SendAll(str.c_str(), str.size(), &n_bytes); + } << [&] { + if (n_bytes != str.size()) { + return Fail("Failed to send string."); + } + return Success(); + }; + SafeColl(rc); + return n_bytes; } -std::size_t TCPSocket::Recv(std::string *p_str) { +[[nodiscard]] Result TCPSocket::Recv(std::string *p_str) { CHECK(!this->IsClosed()); std::int32_t len; - CHECK_EQ(this->RecvAll(&len, sizeof(len)), sizeof(len)) << "Failed to recv string length."; - p_str->resize(len); - auto bytes = this->RecvAll(&(*p_str)[0], len); - CHECK_EQ(bytes, len) << "Failed to recv string."; - return bytes; + std::size_t n_bytes{0}; + return Success() << [&] { + return this->RecvAll(&len, sizeof(len), &n_bytes); + } << [&] { + if (n_bytes != sizeof(len)) { + return Fail("Failed to recv string length."); + } + return Success(); + } << [&] { + p_str->resize(len); + return this->RecvAll(&(*p_str)[0], len, &n_bytes); + } << [&] { + if (static_cast>(n_bytes) != len) { + return Fail("Failed to recv string."); + } + return Success(); + }; } [[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry, @@ -110,11 +137,7 @@ std::size_t TCPSocket::Recv(std::string *p_str) { for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) { if (attempt > 0) { LOG(WARNING) << "Retrying connection to " << host << " for the " << attempt << " time."; -#if defined(_MSC_VER) || defined(__MINGW32__) - Sleep(attempt << 1); -#else - sleep(attempt << 1); -#endif + std::this_thread::sleep_for(std::chrono::seconds{attempt << 1}); } auto rc = connect(conn.Handle(), addr_handle, addr_len); @@ -158,8 +181,8 @@ std::size_t TCPSocket::Recv(std::string *p_str) { std::stringstream ss; ss << "Failed to connect to " << host << ":" << port; - conn.Close(); - return Fail(ss.str(), std::move(last_error)); + auto close_rc = conn.Close(); + return Fail(ss.str(), std::move(close_rc) + std::move(last_error)); } [[nodiscard]] Result GetHostName(std::string *p_out) { diff --git a/src/collective/tracker.cc b/src/collective/tracker.cc index 88c51d8a909f..f4c07c5d145c 100644 --- a/src/collective/tracker.cc +++ b/src/collective/tracker.cc @@ -1,6 +1,7 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ +#include "rabit/internal/socket.h" #if defined(__unix__) || defined(__APPLE__) #include // gethostbyname #include // socket, AF_INET6, AF_INET, connect, getsockname @@ -27,15 +28,23 @@ #include "tracker.h" #include "xgboost/collective/result.h" // for Result, Fail, Success #include "xgboost/collective/socket.h" // for GetHostName, FailWithCode, MakeSockAddress, ... -#include "xgboost/json.h" +#include "xgboost/json.h" // for Json namespace xgboost::collective { + Tracker::Tracker(Json const& config) - : n_workers_{static_cast( - RequiredArg(config, "n_workers", __func__))}, + : sortby_{static_cast( + OptionalArg(config, "sortby", static_cast(SortBy::kHost)))}, + n_workers_{ + static_cast(RequiredArg(config, "n_workers", __func__))}, port_{static_cast(OptionalArg(config, "port", Integer::Int{0}))}, - timeout_{std::chrono::seconds{OptionalArg( - config, "timeout", static_cast(collective::DefaultTimeoutSec()))}} {} + timeout_{std::chrono::seconds{ + OptionalArg(config, "timeout", static_cast(0))}} { + using std::chrono_literals::operator""s; + // Some old configurations in JVM for the scala implementation (removed) use 0 to + // indicate blocking. We continue that convention here. + timeout_ = (timeout_ == 0s) ? -1s : timeout_; +} Result Tracker::WaitUntilReady() const { using namespace std::chrono_literals; // NOLINT @@ -46,7 +55,7 @@ Result Tracker::WaitUntilReady() const { timer.Start(); while (!this->Ready()) { auto ela = timer.Duration().count(); - if (ela > this->Timeout().count()) { + if (HasTimeout(this->Timeout()) && ela > this->Timeout().count()) { return Fail("Failed to start tracker, timeout:" + std::to_string(this->Timeout().count()) + " seconds."); } @@ -56,20 +65,25 @@ Result Tracker::WaitUntilReady() const { return Success(); } -RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr) +RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddress addr) : sock_{std::move(sock)} { std::int32_t rank{0}; Json jcmd; std::int32_t port{0}; - rc_ = Success() << [&] { return proto::Magic{}.Verify(&sock_); } << [&] { + rc_ = Success() << [&] { + return proto::Magic{}.Verify(&sock_); + } << [&] { return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_); } << [&] { std::string cmd; - sock_.Recv(&cmd); + auto rc = sock_.Recv(&cmd); + if (!rc.OK()) { + return rc; + } jcmd = Json::Load(StringView{cmd}); cmd_ = static_cast(get(jcmd["cmd"])); - return Success(); + return rc; } << [&] { if (cmd_ == proto::CMD::kStart) { proto::Start start; @@ -83,28 +97,37 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA } return Success(); } << [&] { - auto host = addr.Addr(); - info_ = proto::PeerInfo{host, port, rank}; + if (addr.IsV4()) { + auto host = addr.V4().Addr(); + info_ = proto::PeerInfo{host, port, rank}; + } else { + auto host = addr.V6().Addr(); + info_ = proto::PeerInfo{host, port, rank}; + } return Success(); }; } RabitTracker::RabitTracker(Json const& config) : Tracker{config} { std::string self; - auto rc = collective::GetHostAddress(&self); - auto host = OptionalArg(config, "host", self); - - host_ = host; - listener_ = TCPSocket::Create(SockDomain::kV4); - rc = listener_.Bind(host, &this->port_); - CHECK(rc.OK()) << rc.Report(); - listener_.Listen(); + auto rc = Success() << [&] { + return collective::GetHostAddress(&self); + } << [&] { + host_ = OptionalArg(config, "host", self); + + auto addr = MakeSockAddress(xgboost::StringView{host_}, 0); + listener_ = TCPSocket::Create(addr.IsV4() ? SockDomain::kV4 : SockDomain::kV6); + return listener_.Bind(host_, &this->port_); + } << [&] { + return listener_.Listen(); + }; + SafeColl(rc); } Result RabitTracker::Bootstrap(std::vector* p_workers) { auto& workers = *p_workers; - std::sort(workers.begin(), workers.end(), WorkerCmp{}); + std::sort(workers.begin(), workers.end(), WorkerCmp{this->sortby_}); std::vector bootstrap_threads; for (std::int32_t r = 0; r < n_workers_; ++r) { @@ -211,9 +234,13 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { // // retry is set to 1, just let the worker timeout or error. Otherwise the // tracker and the worker might be waiting for each other. - auto rc = Connect(w.first, w.second, 1, timeout_, &out); + auto rc = Success() << [&] { + return Connect(w.first, w.second, 1, timeout_, &out); + } << [&] { + return proto::Error{}.SignalError(&out); + }; if (!rc.OK()) { - return Fail("Failed to inform workers to stop."); + return Fail("Failed to inform worker:" + w.first + " for error.", std::move(rc)); } } return Success(); @@ -222,18 +249,45 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { return std::async(std::launch::async, [this, handle_error] { State state{this->n_workers_}; + auto select_accept = [&](TCPSocket* sock, auto* addr) { + // accept with poll so that we can enable timeout and interruption. + rabit::utils::PollHelper poll; + auto rc = Success() << [&] { + std::lock_guard lock{listener_mu_}; + return listener_.NonBlocking(true); + } << [&] { + { + std::lock_guard lock{listener_mu_}; + poll.WatchRead(listener_); + } + if (state.running) { + // Don't timeout if the communicator group is up and running. + return poll.Poll(std::chrono::seconds{-1}); + } else { + // Have timeout for workers to bootstrap. + return poll.Poll(timeout_); + } + } << [&] { + // this->Stop() closes the socket with a lock. Therefore, when the accept returns + // due to shutdown, the state is still valid (closed). + return listener_.Accept(sock, addr); + }; + return rc; + }; + while (state.ShouldContinue()) { TCPSocket sock; - SockAddrV4 addr; + SockAddress addr; this->ready_ = true; - auto rc = listener_.Accept(&sock, &addr); + auto rc = select_accept(&sock, &addr); if (!rc.OK()) { - return Fail("Failed to accept connection.", std::move(rc)); + return Fail("Failed to accept connection.", this->Stop() + std::move(rc)); } auto worker = WorkerProxy{n_workers_, std::move(sock), std::move(addr)}; if (!worker.Status().OK()) { - return Fail("Failed to initialize worker proxy.", std::move(worker.Status())); + LOG(WARNING) << "Failed to initialize worker proxy." << worker.Status().Report(); + continue; } switch (worker.Command()) { case proto::CMD::kStart: { @@ -243,7 +297,7 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { state.Error(); rc = handle_error(worker); if (!rc.OK()) { - return Fail("Failed to handle abort.", std::move(rc)); + return Fail("Failed to handle abort.", this->Stop() + std::move(rc)); } } @@ -253,7 +307,7 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { state.Bootstrap(); } if (!rc.OK()) { - return rc; + return this->Stop() + std::move(rc); } continue; } @@ -280,25 +334,43 @@ Result RabitTracker::Bootstrap(std::vector* p_workers) { } case proto::CMD::kInvalid: default: { - return Fail("Invalid command received."); + return Fail("Invalid command received.", this->Stop()); } } } - ready_ = false; - return Success(); + return this->Stop(); }); } [[nodiscard]] Json RabitTracker::WorkerArgs() const { auto rc = this->WaitUntilReady(); - CHECK(rc.OK()) << rc.Report(); + SafeColl(rc); Json args{Object{}}; - args["DMLC_TRACKER_URI"] = String{host_}; - args["DMLC_TRACKER_PORT"] = this->Port(); + args["dmlc_tracker_uri"] = String{host_}; + args["dmlc_tracker_port"] = this->Port(); return args; } +[[nodiscard]] Result RabitTracker::Stop() { + if (!this->Ready()) { + return Success(); + } + + ready_ = false; + std::lock_guard lock{listener_mu_}; + if (this->listener_.IsClosed()) { + return Success(); + } + + return Success() << [&] { + // This should have the effect of stopping the `accept` call. + return this->listener_.Shutdown(); + } << [&] { + return listener_.Close(); + }; +} + [[nodiscard]] Result GetHostAddress(std::string* out) { auto rc = GetHostName(out); if (!rc.OK()) { diff --git a/src/collective/tracker.h b/src/collective/tracker.h index f336a82f9ee5..b81cf655964b 100644 --- a/src/collective/tracker.h +++ b/src/collective/tracker.h @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once #include // for seconds @@ -15,6 +15,7 @@ #include "xgboost/json.h" // for Json namespace xgboost::collective { +inline bool HasTimeout(std::chrono::seconds timeout) { return timeout.count() > 0; } /** * * @brief Implementation of RABIT tracker. @@ -36,18 +37,28 @@ namespace xgboost::collective { * signal an error to the tracker and the tracker will notify other workers. */ class Tracker { + public: + enum class SortBy : std::int8_t { + kHost = 0, + kTask = 1, + }; + + protected: + // How to sort the workers, either by host name or by task ID. When using a multi-GPU + // setting, multiple workers can occupy the same host, in which case one should sort + // workers by task. Due to compatibility reason, the task ID is not always available, so + // we use host as the default. + SortBy sortby_; + protected: std::int32_t n_workers_{0}; std::int32_t port_{-1}; - std::chrono::seconds timeout_{0}; + std::chrono::seconds timeout_{-1}; std::atomic ready_{false}; public: explicit Tracker(Json const& config); - Tracker(std::int32_t n_worders, std::int32_t port, std::chrono::seconds timeout) - : n_workers_{n_worders}, port_{port}, timeout_{timeout} {} - - virtual ~Tracker() noexcept(false){}; // NOLINT + virtual ~Tracker() = default; [[nodiscard]] Result WaitUntilReady() const; @@ -59,6 +70,11 @@ class Tracker { * @brief Flag to indicate whether the server is running. */ [[nodiscard]] bool Ready() const { return ready_; } + /** + * @brief Shutdown the tracker, cannot be restarted again. Useful when the tracker hangs while + * calling accept. + */ + virtual Result Stop() { return Success(); } }; class RabitTracker : public Tracker { @@ -76,7 +92,7 @@ class RabitTracker : public Tracker { Result rc_; public: - explicit WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr); + explicit WorkerProxy(std::int32_t world, TCPSocket sock, SockAddress addr); WorkerProxy(WorkerProxy const& that) = delete; WorkerProxy(WorkerProxy&& that) = default; WorkerProxy& operator=(WorkerProxy const&) = delete; @@ -96,11 +112,14 @@ class RabitTracker : public Tracker { void Send(StringView value) { this->sock_.Send(value); } }; - // provide an ordering for workers, this helps us get deterministic topology. + // Provide an ordering for workers, this helps us get deterministic topology. struct WorkerCmp { + SortBy sortby; + explicit WorkerCmp(SortBy sortby) : sortby{sortby} {} + [[nodiscard]] bool operator()(WorkerProxy const& lhs, WorkerProxy const& rhs) { - auto const& lh = lhs.Host(); - auto const& rh = rhs.Host(); + auto const& lh = sortby == Tracker::SortBy::kHost ? lhs.Host() : lhs.TaskID(); + auto const& rh = sortby == Tracker::SortBy::kHost ? rhs.Host() : rhs.TaskID(); if (lh != rh) { return lh < rh; @@ -114,28 +133,22 @@ class RabitTracker : public Tracker { // record for how to reach out to workers if error happens. std::vector> worker_error_handles_; // listening socket for incoming workers. - // - // At the moment, the listener calls accept without first polling. We can add an - // additional unix domain socket to allow cancelling the accept. TCPSocket listener_; + // mutex for protecting the listener, used to prevent race when it's listening while + // another thread tries to shut it down. + std::mutex listener_mu_; Result Bootstrap(std::vector* p_workers); public: - explicit RabitTracker(StringView host, std::int32_t n_worders, std::int32_t port, - std::chrono::seconds timeout) - : Tracker{n_worders, port, timeout}, host_{host.c_str(), host.size()} { - listener_ = TCPSocket::Create(SockDomain::kV4); - auto rc = listener_.Bind(host, &this->port_); - CHECK(rc.OK()) << rc.Report(); - listener_.Listen(); - } - explicit RabitTracker(Json const& config); - ~RabitTracker() noexcept(false) override = default; + ~RabitTracker() override = default; std::future Run() override; [[nodiscard]] Json WorkerArgs() const override; + // Stop the tracker without waiting. This is to prevent the tracker from hanging when + // one of the workers failes to start. + [[nodiscard]] Result Stop() override; }; // Prob the public IP address of the host, need a better method. diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h index 440f3c0a87c8..843cee80fbc9 100644 --- a/src/common/column_matrix.h +++ b/src/common/column_matrix.h @@ -72,7 +72,7 @@ class SparseColumnIter : public Column { public: SparseColumnIter(common::Span index, bst_bin_t least_bin_idx, - common::Span row_ind, bst_row_t first_row_idx) + common::Span row_ind, bst_idx_t first_row_idx) : Base{index, least_bin_idx}, row_ind_(row_ind) { // first_row_id is the first row in the leaf partition const size_t* row_data = RowIndices(); @@ -301,7 +301,7 @@ class ColumnMatrix { } template - auto SparseColumn(bst_feature_t fidx, bst_row_t first_row_idx) const { + auto SparseColumn(bst_feature_t fidx, bst_idx_t first_row_idx) const { const size_t feature_offset = feature_offsets_[fidx]; // to get right place for certain feature const size_t column_size = feature_offsets_[fidx + 1] - feature_offset; common::Span bin_index = { @@ -325,7 +325,7 @@ class ColumnMatrix { // all columns are dense column and has no missing value // FIXME(jiamingy): We don't need a column matrix if there's no missing value. template - void SetIndexNoMissing(bst_row_t base_rowid, RowBinIdxT const* row_index, const size_t n_samples, + void SetIndexNoMissing(bst_idx_t base_rowid, RowBinIdxT const* row_index, const size_t n_samples, const size_t n_features, int32_t n_threads) { missing_.GrowTo(feature_offsets_[n_features], false); diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 46f76c41589d..f4fce42f84f8 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -12,18 +12,15 @@ #include // make_transform_output_iterator #include #include -#include #include #include #include #include #include -#include #include // for size_t #include #include -#include #include #include #include @@ -31,7 +28,6 @@ #include "../collective/communicator-inl.h" #include "common.h" -#include "xgboost/global_config.h" #include "xgboost/host_device_vector.h" #include "xgboost/logging.h" #include "xgboost/span.h" @@ -302,21 +298,22 @@ class MemoryLogger { void RegisterAllocation(void *ptr, size_t n) { device_allocations[ptr] = n; currently_allocated_bytes += n; - peak_allocated_bytes = - std::max(peak_allocated_bytes, currently_allocated_bytes); + peak_allocated_bytes = std::max(peak_allocated_bytes, currently_allocated_bytes); num_allocations++; CHECK_GT(num_allocations, num_deallocations); } void RegisterDeallocation(void *ptr, size_t n, int current_device) { auto itr = device_allocations.find(ptr); if (itr == device_allocations.end()) { - LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device " - << current_device << " that was never allocated "; + LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device " << current_device + << " that was never allocated\n" + << dmlc::StackTrace(); + } else { + num_deallocations++; + CHECK_LE(num_deallocations, num_allocations); + currently_allocated_bytes -= itr->second; + device_allocations.erase(itr); } - num_deallocations++; - CHECK_LE(num_deallocations, num_allocations); - currently_allocated_bytes -= itr->second; - device_allocations.erase(itr); } }; DeviceStats stats_; diff --git a/src/common/error_msg.cc b/src/common/error_msg.cc index 8871c1a1d697..cdbe5ebf6652 100644 --- a/src/common/error_msg.cc +++ b/src/common/error_msg.cc @@ -11,7 +11,7 @@ #include "xgboost/logging.h" namespace xgboost::error { -std::string DeprecatedFunc(StringView old, StringView since, StringView replacement) { +[[nodiscard]] std::string DeprecatedFunc(StringView old, StringView since, StringView replacement) { std::stringstream ss; ss << "`" << old << "` is deprecated since" << since << ", use `" << replacement << "` instead."; return ss.str(); diff --git a/src/common/error_msg.h b/src/common/error_msg.h index 7264c3532d65..67114320b7d3 100644 --- a/src/common/error_msg.h +++ b/src/common/error_msg.h @@ -89,7 +89,7 @@ void WarnDeprecatedGPUId(); void WarnEmptyDataset(); -std::string DeprecatedFunc(StringView old, StringView since, StringView replacement); +[[nodiscard]] std::string DeprecatedFunc(StringView old, StringView since, StringView replacement); constexpr StringView InvalidCUDAOrdinal() { return "Invalid device. `device` is required to be CUDA and there must be at least one GPU " diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index f101247920a4..9b703a3fa13a 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -34,7 +34,7 @@ HistogramCuts SketchOnDMatrix(Context const *ctx, DMatrix *m, bst_bin_t max_bins HistogramCuts out; auto const &info = m->Info(); auto n_threads = ctx->Threads(); - std::vector reduced(info.num_col_, 0); + std::vector reduced(info.num_col_, 0); for (auto const &page : m->GetBatches()) { auto const &entries_per_column = CalcColumnSize(data::SparsePageAdapterBatch{page.GetView()}, info.num_col_, n_threads, @@ -209,10 +209,10 @@ void RowsWiseBuildHistKernel(Span gpair, CHECK(offsets); } - auto get_row_ptr = [&](bst_row_t ridx) { + auto get_row_ptr = [&](bst_idx_t ridx) { return kFirstPage ? row_ptr[ridx] : row_ptr[ridx - base_rowid]; }; - auto get_rid = [&](bst_row_t ridx) { return kFirstPage ? ridx : (ridx - base_rowid); }; + auto get_rid = [&](bst_idx_t ridx) { return kFirstPage ? ridx : (ridx - base_rowid); }; const size_t n_features = get_row_ptr(row_indices.begin[0] + 1) - get_row_ptr(row_indices.begin[0]); @@ -275,10 +275,10 @@ void ColsWiseBuildHistKernel(Span gpair, auto const &row_ptr = gmat.row_ptr.data(); auto base_rowid = gmat.base_rowid; const uint32_t *offsets = gmat.index.Offset(); - auto get_row_ptr = [&](bst_row_t ridx) { + auto get_row_ptr = [&](bst_idx_t ridx) { return kFirstPage ? row_ptr[ridx] : row_ptr[ridx - base_rowid]; }; - auto get_rid = [&](bst_row_t ridx) { return kFirstPage ? ridx : (ridx - base_rowid); }; + auto get_rid = [&](bst_idx_t ridx) { return kFirstPage ? ridx : (ridx - base_rowid); }; const size_t n_features = gmat.cut.Ptrs().size() - 1; const size_t n_columns = n_features; diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index fbe6356bf501..39f310ebb66a 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -13,8 +13,6 @@ #include #include // for size_t -#include -#include #include #include @@ -39,7 +37,7 @@ size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows) { return std::min(num_cuts, num_rows); } -size_t RequiredSampleCuts(bst_row_t num_rows, bst_feature_t num_columns, +size_t RequiredSampleCuts(bst_idx_t num_rows, bst_feature_t num_columns, size_t max_bins, size_t nnz) { auto per_column = RequiredSampleCutsPerColumn(max_bins, num_rows); auto if_dense = num_columns * per_column; @@ -47,7 +45,7 @@ size_t RequiredSampleCuts(bst_row_t num_rows, bst_feature_t num_columns, return result; } -size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz, +size_t RequiredMemory(bst_idx_t num_rows, bst_feature_t num_columns, size_t nnz, size_t num_bins, bool with_weights) { size_t peak = 0; // 0. Allocate cut pointer in quantile container by increasing: n_columns + 1 @@ -85,7 +83,7 @@ size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz, return peak; } -size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_row_t num_rows, +size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_idx_t num_rows, bst_feature_t columns, size_t nnz, int device, size_t num_cuts, bool has_weight) { auto constexpr kIntMax = static_cast(std::numeric_limits::max()); @@ -123,7 +121,7 @@ void SortByWeight(dh::device_vector* weights, dh::device_vector* s [=] __device__(const Entry& a, const Entry& b) { return a.index == b.index; }); } -void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span d_cuts_ptr, +void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span d_cuts_ptr, dh::device_vector* p_sorted_entries, dh::device_vector* p_sorted_weights, dh::caching_device_vector* p_column_sizes_scan) { @@ -210,7 +208,7 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c sorted_entries = dh::device_vector(h_data.begin() + begin, h_data.begin() + end); } - bst_row_t base_rowid = page.base_rowid; + bst_idx_t base_rowid = page.base_rowid; dh::device_vector entry_weight; auto cuctx = ctx->CUDACtx(); diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index 3cd13030ef40..cf1043ddb399 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -8,6 +8,7 @@ #define COMMON_HIST_UTIL_CUH_ #include +#include // for sort #include // for size_t @@ -186,7 +187,7 @@ inline size_t constexpr BytesPerElement(bool has_weight) { * directly if it's not 0. */ size_t SketchBatchNumElements(size_t sketch_batch_num_elements, - bst_row_t num_rows, bst_feature_t columns, + bst_idx_t num_rows, bst_feature_t columns, size_t nnz, int device, size_t num_cuts, bool has_weight); @@ -209,7 +210,7 @@ size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows); * * \return The estimated bytes */ -size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz, +size_t RequiredMemory(bst_idx_t num_rows, bst_feature_t num_columns, size_t nnz, size_t num_bins, bool with_weights); // Count the valid entries in each column and copy them out. @@ -240,7 +241,7 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, Ran void SortByWeight(dh::device_vector* weights, dh::device_vector* sorted_entries); -void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span d_cuts_ptr, +void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span d_cuts_ptr, dh::device_vector* p_sorted_entries, dh::device_vector* p_sorted_weights, dh::caching_device_vector* p_column_sizes_scan); diff --git a/src/common/host_device_vector.cc b/src/common/host_device_vector.cc index a7a996c6c1ff..f4973c0428f0 100644 --- a/src/common/host_device_vector.cc +++ b/src/common/host_device_vector.cc @@ -178,7 +178,7 @@ template class HostDeviceVector; template class HostDeviceVector; template class HostDeviceVector; template class HostDeviceVector; -template class HostDeviceVector; // bst_row_t +template class HostDeviceVector; template class HostDeviceVector; // bst_feature_t #if defined(__APPLE__) || defined(__EMSCRIPTEN__) diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu index 4933a4b11344..99448df21b7e 100644 --- a/src/common/host_device_vector.cu +++ b/src/common/host_device_vector.cu @@ -6,7 +6,6 @@ #include #include -#include #include "xgboost/data.h" #include "xgboost/host_device_vector.h" @@ -412,7 +411,7 @@ template class HostDeviceVector; template class HostDeviceVector; template class HostDeviceVector; template class HostDeviceVector; -template class HostDeviceVector; // bst_row_t +template class HostDeviceVector; template class HostDeviceVector; // bst_feature_t template class HostDeviceVector; template class HostDeviceVector; diff --git a/src/common/io.h b/src/common/io.h index 5e9d275829ae..d2fcc9f92471 100644 --- a/src/common/io.h +++ b/src/common/io.h @@ -1,5 +1,5 @@ /** - * Copyright 2014-2023, XGBoost Contributors + * Copyright 2014-2024, XGBoost Contributors * \file io.h * \brief general stream interface for serialization, I/O * \author Tianqi Chen @@ -8,7 +8,6 @@ #define XGBOOST_COMMON_IO_H_ #include -#include // for MemoryFixSizeBuffer, MemoryBufferStream #include // for min, fill_n, copy_n #include // for array @@ -23,12 +22,99 @@ #include // for move #include // for vector -#include "common.h" +#include "common.h" // for DivRoundUp #include "xgboost/string_view.h" // for StringView namespace xgboost::common { -using MemoryFixSizeBuffer = rabit::utils::MemoryFixSizeBuffer; -using MemoryBufferStream = rabit::utils::MemoryBufferStream; +struct MemoryFixSizeBuffer : public dmlc::SeekStream { + public: + // similar to SEEK_END in libc + static std::size_t constexpr kSeekEnd = std::numeric_limits::max(); + + public: + /** + * @brief Ctor + * + * @param p_buffer Pointer to the source buffer with size `buffer_size`. + * @param buffer_size Size of the source buffer + */ + MemoryFixSizeBuffer(void *p_buffer, std::size_t buffer_size) + : p_buffer_(reinterpret_cast(p_buffer)), buffer_size_(buffer_size) {} + ~MemoryFixSizeBuffer() override = default; + + std::size_t Read(void *ptr, std::size_t size) override { + std::size_t nread = std::min(buffer_size_ - curr_ptr_, size); + if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread); + curr_ptr_ += nread; + return nread; + } + void Write(const void *ptr, std::size_t size) override { + if (size == 0) return; + CHECK_LE(curr_ptr_ + size, buffer_size_); + std::memcpy(p_buffer_ + curr_ptr_, ptr, size); + curr_ptr_ += size; + } + void Seek(std::size_t pos) override { + if (pos == kSeekEnd) { + curr_ptr_ = buffer_size_; + } else { + curr_ptr_ = static_cast(pos); + } + } + /** + * @brief Current position in the buffer (stream). + */ + std::size_t Tell() override { return curr_ptr_; } + [[nodiscard]] virtual bool AtEnd() const { return curr_ptr_ == buffer_size_; } + + protected: + /*! \brief in memory buffer */ + char *p_buffer_{nullptr}; + /*! \brief current pointer */ + std::size_t buffer_size_{0}; + /*! \brief current pointer */ + std::size_t curr_ptr_{0}; +}; + +/*! \brief a in memory buffer that can be read and write as stream interface */ +struct MemoryBufferStream : public dmlc::SeekStream { + public: + explicit MemoryBufferStream(std::string *p_buffer) + : p_buffer_(p_buffer) { + curr_ptr_ = 0; + } + ~MemoryBufferStream() override = default; + size_t Read(void *ptr, size_t size) override { + CHECK_LE(curr_ptr_, p_buffer_->length()) << "read can not have position excceed buffer length"; + size_t nread = std::min(p_buffer_->length() - curr_ptr_, size); + if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread); + curr_ptr_ += nread; + return nread; + } + void Write(const void *ptr, size_t size) override { + if (size == 0) return; + if (curr_ptr_ + size > p_buffer_->length()) { + p_buffer_->resize(curr_ptr_+size); + } + std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size); + curr_ptr_ += size; + } + void Seek(size_t pos) override { + curr_ptr_ = static_cast(pos); + } + size_t Tell() override { + return curr_ptr_; + } + virtual bool AtEnd() const { + return curr_ptr_ == p_buffer_->length(); + } + + private: + /*! \brief in memory buffer */ + std::string *p_buffer_; + /*! \brief current pointer */ + size_t curr_ptr_; +}; // class MemoryBufferStream /*! * \brief Input stream that support additional PeekRead operation, diff --git a/src/common/linalg_op.cuh b/src/common/linalg_op.cuh index 9ef36598d9de..21fad2dc0b4a 100644 --- a/src/common/linalg_op.cuh +++ b/src/common/linalg_op.cuh @@ -13,15 +13,14 @@ #include "xgboost/context.h" // for Context #include "xgboost/linalg.h" // for TensorView -namespace xgboost { -namespace linalg { +namespace xgboost::linalg { namespace cuda_impl { // Use template specialization to dispatch, Windows + CUDA 11.8 doesn't support extended // lambda inside constexpr if template struct ElementWiseImpl { template - void operator()(linalg::TensorView t, Fn&& fn, cudaStream_t s) { + void operator()(TensorView t, Fn&& fn, cudaStream_t s) { static_assert(D > 1); dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) mutable { std::apply(fn, linalg::UnravelIndex(i, t.Shape())); @@ -32,36 +31,59 @@ struct ElementWiseImpl { template struct ElementWiseImpl { template - void operator()(linalg::TensorView t, Fn&& fn, cudaStream_t s) { + void operator()(TensorView t, Fn&& fn, cudaStream_t s) { dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) { fn(i); }); } }; template -void ElementWiseKernel(linalg::TensorView t, Fn&& fn, cudaStream_t s = nullptr) { +void ElementWiseKernel(TensorView t, Fn&& fn, cudaStream_t s = nullptr) { dh::safe_cuda(cudaSetDevice(t.Device().ordinal)); cuda_impl::ElementWiseImpl{}(t, fn, s); } } // namespace cuda_impl template -void ElementWiseTransformDevice(linalg::TensorView t, Fn&& fn, cudaStream_t s = nullptr) { +void ElementWiseTransformDevice(TensorView t, Fn&& fn, cudaStream_t s = nullptr) { if (t.Contiguous()) { auto ptr = t.Values().data(); dh::LaunchN(t.Size(), s, [=] __device__(size_t i) { ptr[i] = fn(i, ptr[i]); }); } else { dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable { - T& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape())); + T& v = detail::Apply(t, UnravelIndex(i, t.Shape())); v = fn(i, v); }); } } template -void ElementWiseKernel(Context const* ctx, linalg::TensorView t, Fn&& fn) { +void ElementWiseKernel(Context const* ctx, TensorView t, Fn&& fn) { ctx->IsCUDA() ? cuda_impl::ElementWiseKernel(t, fn) : ElementWiseKernelHost(t, ctx->Threads(), fn); } -} // namespace linalg -} // namespace xgboost + +namespace detail { +template +struct IterOp { + TensorView v; + XGBOOST_DEVICE T& operator()(std::size_t i) { + return detail::Apply(v, UnravelIndex(i, v.Shape())); + } +}; +} // namespace detail + +// naming: thrust begin +// returns a thrust iterator for a tensor view. +template +auto tcbegin(TensorView v) { // NOLINT + return dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), + detail::IterOp>, kDim>{v}); +} + +template +auto tcend(TensorView v) { // NOLINT + return tcbegin(v) + v.Size(); +} +} // namespace xgboost::linalg #endif // XGBOOST_COMMON_LINALG_OP_CUH_ diff --git a/src/common/quantile.cc b/src/common/quantile.cc index e293ac739298..232bb7800c5b 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -4,6 +4,7 @@ #include "quantile.h" #include +#include // for partial_sum #include #include "../collective/aggregator.h" @@ -14,7 +15,7 @@ namespace xgboost::common { template SketchContainerImpl::SketchContainerImpl(Context const *ctx, - std::vector columns_size, + std::vector columns_size, int32_t max_bins, Span feature_types, bool use_group) @@ -115,19 +116,19 @@ INSTANTIATE(ColumnarAdapterBatch) namespace { /** - * \brief A view over gathered sketch values. + * @brief A view over gathered sketch values. */ template struct QuantileAllreduce { common::Span global_values; - common::Span worker_indptr; - common::Span feature_indptr; - size_t n_features{0}; + common::Span worker_indptr; + common::Span feature_indptr; + bst_feature_t n_features{0}; /** - * \brief Get sketch values of the a feature from a worker. + * @brief Get sketch values of the a feature from a worker. * - * \param rank rank of target worker - * \param fidx feature idx + * @param rank rank of target worker + * @param fidx feature idx */ [[nodiscard]] auto Values(int32_t rank, bst_feature_t fidx) const { // get span for worker @@ -147,16 +148,16 @@ template void SketchContainerImpl::GatherSketchInfo( Context const *ctx, MetaInfo const &info, std::vector const &reduced, - std::vector *p_worker_segments, std::vector *p_sketches_scan, + std::vector *p_worker_segments, std::vector *p_sketches_scan, std::vector *p_global_sketches) { auto &worker_segments = *p_worker_segments; worker_segments.resize(1, 0); auto world = collective::GetWorldSize(); auto rank = collective::GetRank(); - auto n_columns = sketches_.size(); + bst_feature_t n_columns = sketches_.size(); // get the size of each feature. - std::vector sketch_size; + std::vector sketch_size; for (size_t i = 0; i < reduced.size(); ++i) { if (IsCat(feature_types_, i)) { sketch_size.push_back(0); @@ -164,8 +165,8 @@ void SketchContainerImpl::GatherSketchInfo( sketch_size.push_back(reduced[i].size); } } - // turn the size into CSC indptr - std::vector &sketches_scan = *p_sketches_scan; + // Turn the size into CSC indptr + std::vector &sketches_scan = *p_sketches_scan; sketches_scan.resize((n_columns + 1) * world, 0); size_t beg_scan = rank * (n_columns + 1); // starting storage for current worker. std::partial_sum(sketch_size.cbegin(), sketch_size.cend(), sketches_scan.begin() + beg_scan + 1); @@ -173,7 +174,10 @@ void SketchContainerImpl::GatherSketchInfo( // Gather all column pointers auto rc = collective::GlobalSum(ctx, info, linalg::MakeVec(sketches_scan.data(), sketches_scan.size())); - collective::SafeColl(rc); + if (!rc.OK()) { + collective::SafeColl(collective::Fail("Failed to get sketch scan.", std::move(rc))); + } + for (int32_t i = 0; i < world; ++i) { size_t back = (i + 1) * (n_columns + 1) - 1; auto n_entries = sketches_scan.at(back); @@ -205,7 +209,9 @@ void SketchContainerImpl::GatherSketchInfo( ctx, info, linalg::MakeVec(reinterpret_cast(global_sketches.data()), global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float))); - collective::SafeColl(rc); + if (!rc.OK()) { + collective::SafeColl(collective::Fail("Failed to get sketch.", std::move(rc))); + } } template @@ -226,7 +232,7 @@ void SketchContainerImpl::AllreduceCategories(Context const* ctx, Meta CHECK_EQ(feature_ptr.front(), 0); // gather all feature ptrs from workers - std::vector global_feat_ptrs(feature_ptr.size() * world_size, 0); + std::vector global_feat_ptrs(feature_ptr.size() * world_size, 0); size_t feat_begin = rank * feature_ptr.size(); // pointer to current worker std::copy(feature_ptr.begin(), feature_ptr.end(), global_feat_ptrs.begin() + feat_begin); auto rc = collective::GlobalSum( @@ -241,7 +247,7 @@ void SketchContainerImpl::AllreduceCategories(Context const* ctx, Meta } // indptr for indexing workers - std::vector global_worker_ptr(world_size + 1, 0); + std::vector global_worker_ptr(world_size + 1, 0); global_worker_ptr[rank + 1] = total; // shift 1 to right for constructing the indptr rc = collective::GlobalSum(ctx, info, linalg::MakeVec(global_worker_ptr.data(), global_worker_ptr.size())); @@ -259,7 +265,7 @@ void SketchContainerImpl::AllreduceCategories(Context const* ctx, Meta rc = collective::GlobalSum(ctx, info, linalg::MakeVec(global_categories.data(), global_categories.size())); QuantileAllreduce allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs, - categories_.size()}; + static_cast(categories_.size())}; ParallelFor(categories_.size(), n_threads_, [&](auto fidx) { if (!IsCat(feature_types_, fidx)) { return; @@ -284,8 +290,9 @@ void SketchContainerImpl::AllReduce( std::vector *p_reduced, std::vector *p_num_cuts) { monitor_.Start(__func__); - size_t n_columns = sketches_.size(); - collective::Allreduce(&n_columns, 1); + bst_feature_t n_columns = sketches_.size(); + auto rc = collective::Allreduce(ctx, &n_columns, collective::Op::kMax); + collective::SafeColl(rc); CHECK_EQ(n_columns, sketches_.size()) << "Number of columns differs across workers"; AllreduceCategories(ctx, info); @@ -298,14 +305,14 @@ void SketchContainerImpl::AllReduce( reduced.resize(sketches_.size()); // Prune the intermediate num cuts for synchronization. - std::vector global_column_size(columns_size_); - auto rc = collective::GlobalSum( - ctx, info, linalg::MakeVec(global_column_size.data(), global_column_size.size())); + std::vector global_column_size(columns_size_); + rc = collective::GlobalSum(ctx, info, + linalg::MakeVec(global_column_size.data(), global_column_size.size())); collective::SafeColl(rc); ParallelFor(sketches_.size(), n_threads_, [&](size_t i) { int32_t intermediate_num_cuts = static_cast( - std::min(global_column_size[i], static_cast(max_bins_ * WQSketch::kFactor))); + std::min(global_column_size[i], static_cast(max_bins_ * WQSketch::kFactor))); if (global_column_size[i] == 0) { return; } @@ -327,8 +334,8 @@ void SketchContainerImpl::AllReduce( return; } - std::vector worker_segments(1, 0); // CSC pointer to sketches. - std::vector sketches_scan((n_columns + 1) * world, 0); + std::vector worker_segments(1, 0); // CSC pointer to sketches. + std::vector sketches_scan((n_columns + 1) * world, 0); std::vector global_sketches; this->GatherSketchInfo(ctx, info, reduced, &worker_segments, &sketches_scan, &global_sketches); @@ -361,14 +368,14 @@ void SketchContainerImpl::AllReduce( } template -bool AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_bin, - HistogramCuts *cuts, bool secure) { - size_t required_cuts = std::min(summary.size, static_cast(max_bin)); +bool AddCutPoint(Context const *ctx, typename SketchType::SummaryContainer const &summary, + int max_bin, HistogramCuts *cuts, bool secure) { + bst_idx_t required_cuts = std::min(summary.size, static_cast(max_bin)); // make a copy of required_cuts for mode selection size_t required_cuts_original = required_cuts; if (secure) { // sync the required_cuts across all workers - collective::Allreduce(&required_cuts, 1); + collective::SafeColl(collective::Allreduce(ctx, &required_cuts, collective::Op::kMax)); } // add the cut points auto &cut_values = cuts->cut_values_.HostVector(); @@ -438,18 +445,18 @@ void SketchContainerImpl::MakeCuts(Context const *ctx, MetaInfo const float max_cat{-1.f}; for (size_t fid = 0; fid < reduced.size(); ++fid) { - size_t max_num_bins = std::min(num_cuts[fid], max_bins_); + std::int32_t max_num_bins = std::min(num_cuts[fid], max_bins_); // If vertical and secure mode, we need to sync the max_num_bins aross workers // to create the same global number of cut point bins for easier future processing if (info.IsVerticalFederated() && info.IsSecure()) { - collective::Allreduce(&max_num_bins, 1); + collective::SafeColl(collective::Allreduce(ctx, &max_num_bins, collective::Op::kMax)); } typename WQSketch::SummaryContainer const &a = final_summaries[fid]; if (IsCat(feature_types_, fid)) { max_cat = std::max(max_cat, AddCategories(categories_.at(fid), p_cuts)); } else { // use special AddCutPoint scheme for secure vertical federated learning - bool is_nan = AddCutPoint(a, max_num_bins, p_cuts, info.IsSecure()); + bool is_nan = AddCutPoint(ctx, a, max_num_bins, p_cuts, info.IsSecure()); // push a value that is greater than anything if the feature is not empty // i.e. if the last value is not NaN if (!is_nan) { @@ -479,11 +486,11 @@ template class SketchContainerImpl>; HostSketchContainer::HostSketchContainer(Context const *ctx, bst_bin_t max_bins, common::Span ft, - std::vector columns_size, bool use_group) + std::vector columns_size, bool use_group) : SketchContainerImpl{ctx, columns_size, max_bins, ft, use_group} { monitor_.Init(__func__); ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) { - auto n_bins = std::min(static_cast(max_bins_), columns_size_[i]); + auto n_bins = std::min(static_cast(max_bins_), columns_size_[i]); n_bins = std::max(n_bins, static_cast(1)); auto eps = 1.0 / (static_cast(n_bins) * WQSketch::kFactor); if (!IsCat(this->feature_types_, i)) { diff --git a/src/common/quantile.cu b/src/common/quantile.cu index 4b110f5e0164..d0356ae421c7 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -1,5 +1,5 @@ /** - * Copyright 2020-2023 by XGBoost Contributors + * Copyright 2020-2024, XGBoost Contributors */ #include #include @@ -8,11 +8,12 @@ #include #include -#include // std::numeric_limits -#include +#include // std::numeric_limits +#include // for partial_sum #include -#include "../collective/communicator-inl.cuh" +#include "../collective/allgather.h" +#include "../collective/allreduce.h" #include "categorical.h" #include "common.h" #include "device_helpers.cuh" @@ -114,16 +115,16 @@ void CopyTo(Span out, Span src) { // Compute the merge path. common::Span> MergePath( - Span const &d_x, Span const &x_ptr, - Span const &d_y, Span const &y_ptr, - Span out, Span out_ptr) { + Span const &d_x, Span const &x_ptr, + Span const &d_y, Span const &y_ptr, + Span out, Span out_ptr) { auto x_merge_key_it = thrust::make_zip_iterator(thrust::make_tuple( - dh::MakeTransformIterator( + dh::MakeTransformIterator( thrust::make_counting_iterator(0ul), [=] __device__(size_t idx) { return dh::SegmentId(x_ptr, idx); }), d_x.data())); auto y_merge_key_it = thrust::make_zip_iterator(thrust::make_tuple( - dh::MakeTransformIterator( + dh::MakeTransformIterator( thrust::make_counting_iterator(0ul), [=] __device__(size_t idx) { return dh::SegmentId(y_ptr, idx); }), d_y.data())); @@ -173,13 +174,13 @@ common::Span> MergePath( auto scan_key_it = dh::MakeTransformIterator( thrust::make_counting_iterator(0ul), - [=] __device__(size_t idx) { return dh::SegmentId(out_ptr, idx); }); + [=] XGBOOST_DEVICE(size_t idx) { return dh::SegmentId(out_ptr, idx); }); auto scan_val_it = dh::MakeTransformIterator( - merge_path.data(), [=] __device__(Tuple const &t) -> Tuple { + merge_path.data(), [=] XGBOOST_DEVICE(Tuple const &t) -> Tuple { auto ind = get_ind(t); // == 0 if element is from x // x_counter, y_counter - return thrust::make_tuple(!ind, ind); + return thrust::tuple{!ind, ind}; }); // Compute the index for both x and y (which of the element in a and b are used in each @@ -206,8 +207,8 @@ common::Span> MergePath( // run it in 2 passes to obtain the merge path and then customize the standard merge // algorithm. void MergeImpl(DeviceOrd device, Span const &d_x, - Span const &x_ptr, Span const &d_y, - Span const &y_ptr, Span out, Span out_ptr) { + Span const &x_ptr, Span const &d_y, + Span const &y_ptr, Span out, Span out_ptr) { dh::safe_cuda(cudaSetDevice(device.ordinal)); CHECK_EQ(d_x.size() + d_y.size(), out.size()); CHECK_EQ(x_ptr.size(), out_ptr.size()); @@ -499,7 +500,7 @@ void SketchContainer::FixError() { }); } -void SketchContainer::AllReduce(Context const*, bool is_column_split) { +void SketchContainer::AllReduce(Context const* ctx, bool is_column_split) { dh::safe_cuda(cudaSetDevice(device_.ordinal)); auto world = collective::GetWorldSize(); if (world == 1 || is_column_split) { @@ -508,16 +509,18 @@ void SketchContainer::AllReduce(Context const*, bool is_column_split) { timer_.Start(__func__); // Reduce the overhead on syncing. - size_t global_sum_rows = num_rows_; - collective::Allreduce(&global_sum_rows, 1); - size_t intermediate_num_cuts = + bst_idx_t global_sum_rows = num_rows_; + auto rc = collective::Allreduce(ctx, linalg::MakeVec(&global_sum_rows, 1), collective::Op::kSum); + SafeColl(rc); + bst_idx_t intermediate_num_cuts = std::min(global_sum_rows, static_cast(num_bins_ * kFactor)); this->Prune(intermediate_num_cuts); auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan(); CHECK_EQ(d_columns_ptr.size(), num_columns_ + 1); size_t n = d_columns_ptr.size(); - collective::Allreduce(&n, 1); + rc = collective::Allreduce(ctx, linalg::MakeVec(&n, 1), collective::Op::kMax); + SafeColl(rc); CHECK_EQ(n, d_columns_ptr.size()) << "Number of columns differs across workers"; // Get the columns ptr from all workers @@ -527,18 +530,25 @@ void SketchContainer::AllReduce(Context const*, bool is_column_split) { auto offset = rank * d_columns_ptr.size(); thrust::copy(thrust::device, d_columns_ptr.data(), d_columns_ptr.data() + d_columns_ptr.size(), gathered_ptrs.begin() + offset); - collective::AllReduce(device_.ordinal, gathered_ptrs.data().get(), - gathered_ptrs.size()); + rc = collective::Allreduce( + ctx, linalg::MakeVec(gathered_ptrs.data().get(), gathered_ptrs.size(), ctx->Device()), + collective::Op::kSum); + SafeColl(rc); // Get the data from all workers. - std::vector recv_lengths; - dh::caching_device_vector recvbuf; - collective::AllGatherV(device_.ordinal, this->Current().data().get(), - dh::ToSpan(this->Current()).size_bytes(), &recv_lengths, &recvbuf); - collective::Synchronize(device_.ordinal); + std::vector recv_lengths; + HostDeviceVector recvbuf; + rc = collective::AllgatherV( + ctx, linalg::MakeVec(this->Current().data().get(), this->Current().size(), device_), + &recv_lengths, &recvbuf); + collective::SafeColl(rc); + for (std::size_t i = 0; i < recv_lengths.size() - 1; ++i) { + recv_lengths[i] = recv_lengths[i + 1] - recv_lengths[i]; + } + recv_lengths.resize(recv_lengths.size() - 1); // Segment the received data. - auto s_recvbuf = dh::ToSpan(recvbuf); + auto s_recvbuf = recvbuf.DeviceSpan(); std::vector> allworkers; offset = 0; for (int32_t i = 0; i < world; ++i) { diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh index f7124f079b6d..898da03a0dce 100644 --- a/src/common/quantile.cuh +++ b/src/common/quantile.cuh @@ -1,8 +1,9 @@ +/** + * Copyright 2020-2024, XGBoost Contributors + */ #ifndef XGBOOST_COMMON_QUANTILE_CUH_ #define XGBOOST_COMMON_QUANTILE_CUH_ -#include - #include "xgboost/span.h" #include "xgboost/data.h" #include "device_helpers.cuh" @@ -32,13 +33,13 @@ struct SketchUnique { class SketchContainer { public: static constexpr float kFactor = WQSketch::kFactor; - using OffsetT = bst_row_t; + using OffsetT = bst_idx_t; static_assert(sizeof(OffsetT) == sizeof(size_t), "Wrong type for sketch element offset."); private: Monitor timer_; HostDeviceVector feature_types_; - bst_row_t num_rows_; + bst_idx_t num_rows_; bst_feature_t num_columns_; int32_t num_bins_; DeviceOrd device_; @@ -94,7 +95,7 @@ class SketchContainer { * \param device GPU ID. */ SketchContainer(HostDeviceVector const& feature_types, int32_t max_bin, - bst_feature_t num_columns, bst_row_t num_rows, DeviceOrd device) + bst_feature_t num_columns, bst_idx_t num_rows, DeviceOrd device) : num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin}, device_{device} { CHECK(device.IsCUDA()); // Initialize Sketches for this dmatrix diff --git a/src/common/quantile.h b/src/common/quantile.h index 0af93a03e021..59bc3a4f74b1 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -1,5 +1,5 @@ /** - * Copyright 2014-2023 by XGBoost Contributors + * Copyright 2014-2024, XGBoost Contributors * \file quantile.h * \brief util to compute quantiles * \author Tianqi Chen @@ -701,12 +701,12 @@ inline std::vector UnrollGroupWeights(MetaInfo const &info) { auto n_groups = group_ptr.size() - 1; CHECK_EQ(info.weights_.Size(), n_groups) << error::GroupWeight(); - bst_row_t n_samples = info.num_row_; + bst_idx_t n_samples = info.num_row_; std::vector results(n_samples); CHECK_EQ(group_ptr.back(), n_samples) << error::GroupSize() << " the number of rows from the data."; size_t cur_group = 0; - for (bst_row_t i = 0; i < n_samples; ++i) { + for (bst_idx_t i = 0; i < n_samples; ++i) { results[i] = group_weights[cur_group]; if (i == group_ptr[cur_group + 1]) { cur_group++; @@ -719,9 +719,9 @@ inline std::vector UnrollGroupWeights(MetaInfo const &info) { class HistogramCuts; template -std::vector CalcColumnSize(Batch const &batch, bst_feature_t const n_columns, +std::vector CalcColumnSize(Batch const &batch, bst_feature_t const n_columns, size_t const n_threads, IsValid &&is_valid) { - std::vector> column_sizes_tloc(n_threads); + std::vector> column_sizes_tloc(n_threads); for (auto &column : column_sizes_tloc) { column.resize(n_columns, 0); } @@ -759,7 +759,7 @@ std::vector LoadBalance(Batch const &batch, size_t nnz, bst_featu size_t const entries_per_thread = DivRoundUp(total_entries, nthreads); // Need to calculate the size for each batch. - std::vector entries_per_columns = CalcColumnSize(batch, n_columns, nthreads, is_valid); + std::vector entries_per_columns = CalcColumnSize(batch, n_columns, nthreads, is_valid); std::vector cols_ptr(nthreads + 1, 0); size_t count{0}; size_t current_thread{1}; @@ -791,8 +791,8 @@ class SketchContainerImpl { std::vector> categories_; std::vector const feature_types_; - std::vector columns_size_; - int32_t max_bins_; + std::vector columns_size_; + bst_bin_t max_bins_; bool use_group_ind_{false}; int32_t n_threads_; bool has_categorical_{false}; @@ -805,7 +805,7 @@ class SketchContainerImpl { * \param max_bins maximum number of bins for each feature. * \param use_group whether is assigned to group to data instance. */ - SketchContainerImpl(Context const *ctx, std::vector columns_size, int32_t max_bins, + SketchContainerImpl(Context const *ctx, std::vector columns_size, bst_bin_t max_bins, common::Span feature_types, bool use_group); static bool UseGroup(MetaInfo const &info) { @@ -829,8 +829,8 @@ class SketchContainerImpl { // Gather sketches from all workers. void GatherSketchInfo(Context const *ctx, MetaInfo const &info, std::vector const &reduced, - std::vector *p_worker_segments, - std::vector *p_sketches_scan, + std::vector *p_worker_segments, + std::vector *p_sketches_scan, std::vector *p_global_sketches); // Merge sketches from all workers. void AllReduce(Context const *ctx, MetaInfo const &info, @@ -901,7 +901,7 @@ class HostSketchContainer : public SketchContainerImpl ft, - std::vector columns_size, bool use_group); + std::vector columns_size, bool use_group); template void PushAdapterBatch(Batch const &batch, size_t base_rowid, MetaInfo const &info, float missing); @@ -998,7 +998,7 @@ class SortedSketchContainer : public SketchContainerImpl ft, - std::vector columns_size, bool use_group) + std::vector columns_size, bool use_group) : SketchContainerImpl{ctx, columns_size, max_bins, ft, use_group} { monitor_.Init(__func__); sketches_.resize(columns_size.size()); diff --git a/src/common/random.h b/src/common/random.h index ece6fa46f16c..6d7a1bb499c9 100644 --- a/src/common/random.h +++ b/src/common/random.h @@ -1,5 +1,5 @@ /** - * Copyright 2015-2020, XGBoost Contributors + * Copyright 2015-2024, XGBoost Contributors * \file random.h * \brief Utility related to random. * \author Tianqi Chen @@ -19,11 +19,13 @@ #include #include +#include "../collective/broadcast.h" // for Broadcast #include "../collective/communicator-inl.h" #include "algorithm.h" // ArgSort #include "common.h" #include "xgboost/context.h" // Context #include "xgboost/host_device_vector.h" +#include "xgboost/linalg.h" namespace xgboost::common { /*! @@ -227,9 +229,10 @@ class ColumnSampler { } }; -inline auto MakeColumnSampler(Context const*) { +inline auto MakeColumnSampler(Context const* ctx) { std::uint32_t seed = common::GlobalRandomEngine()(); - collective::Broadcast(&seed, sizeof(seed), 0); + auto rc = collective::Broadcast(ctx, linalg::MakeVec(&seed, 1), 0); + collective::SafeColl(rc); auto cs = std::make_shared(seed); return cs; } diff --git a/src/common/ranking_utils.h b/src/common/ranking_utils.h index e6b87ed4b099..acba0feeb2a4 100644 --- a/src/common/ranking_utils.h +++ b/src/common/ranking_utils.h @@ -78,6 +78,7 @@ struct LambdaRankParam : public XGBoostParameter { // unbiased bool lambdarank_unbiased{false}; + bool lambdarank_normalization{true}; double lambdarank_bias_norm{1.0}; // ndcg bool ndcg_exp_gain{true}; @@ -86,6 +87,7 @@ struct LambdaRankParam : public XGBoostParameter { return lambdarank_pair_method == that.lambdarank_pair_method && lambdarank_num_pair_per_sample == that.lambdarank_num_pair_per_sample && lambdarank_unbiased == that.lambdarank_unbiased && + lambdarank_normalization == that.lambdarank_normalization && lambdarank_bias_norm == that.lambdarank_bias_norm && ndcg_exp_gain == that.ndcg_exp_gain; } bool operator!=(LambdaRankParam const& that) const { return !(*this == that); } @@ -134,6 +136,9 @@ struct LambdaRankParam : public XGBoostParameter { DMLC_DECLARE_FIELD(lambdarank_unbiased) .set_default(false) .describe("Unbiased lambda mart. Use extended IPW to debias click position"); + DMLC_DECLARE_FIELD(lambdarank_normalization) + .set_default(true) + .describe("Whether to normalize the leaf value for lambda rank."); DMLC_DECLARE_FIELD(lambdarank_bias_norm) .set_default(1.0) .set_lower_bound(0.0) diff --git a/src/common/threadpool.h b/src/common/threadpool.h new file mode 100644 index 000000000000..95d1deaaabc3 --- /dev/null +++ b/src/common/threadpool.h @@ -0,0 +1,96 @@ +/** + * Copyright 2024, XGBoost Contributors + */ +#pragma once +#include // for condition_variable +#include // for int32_t +#include // for function +#include // for promise +#include // for make_shared +#include // for mutex, unique_lock +#include // for queue +#include // for thread +#include // for invoke_result_t +#include // for move +#include // for vector + +namespace xgboost::common { +/** + * @brief Simple implementation of a thread pool. + */ +class ThreadPool { + std::mutex mu_; + std::queue> tasks_; + std::condition_variable cv_; + std::vector pool_; + bool stop_{false}; + + public: + explicit ThreadPool(std::int32_t n_threads) { + for (std::int32_t i = 0; i < n_threads; ++i) { + pool_.emplace_back([&] { + while (true) { + std::unique_lock lock{mu_}; + cv_.wait(lock, [this] { return !this->tasks_.empty() || stop_; }); + + if (this->stop_) { + if (!tasks_.empty()) { + while (!tasks_.empty()) { + auto fn = tasks_.front(); + tasks_.pop(); + fn(); + } + } + return; + } + + auto fn = tasks_.front(); + tasks_.pop(); + lock.unlock(); + fn(); + } + }); + } + } + + ~ThreadPool() { + std::unique_lock lock{mu_}; + stop_ = true; + lock.unlock(); + + for (auto& t : pool_) { + if (t.joinable()) { + std::unique_lock lock{mu_}; + this->cv_.notify_one(); + lock.unlock(); + } + } + + for (auto& t : pool_) { + if (t.joinable()) { + t.join(); + } + } + } + + /** + * @brief Submit a function that doesn't take any argument. + */ + template > + auto Submit(Fn&& fn) { + // Use shared ptr to make the task copy constructible. + auto p{std::make_shared>()}; + auto fut = p->get_future(); + auto ffn = std::function{[task = std::move(p), fn = std::move(fn)]() mutable { + task->set_value(fn()); + }}; + + std::unique_lock lock{mu_}; + this->tasks_.push(std::move(ffn)); + lock.unlock(); + + cv_.notify_one(); + return fut; + } +}; +} // namespace xgboost::common diff --git a/src/common/timer.cc b/src/common/timer.cc index 99150aa2695e..9b1f49fbd5c8 100644 --- a/src/common/timer.cc +++ b/src/common/timer.cc @@ -1,20 +1,17 @@ -/*! - * Copyright by Contributors 2019 +/** + * Copyright 2019-2024, XGBoost Contributors */ #include "timer.h" -#include #include #include "../collective/communicator-inl.h" #if defined(XGBOOST_USE_NVTX) -#include +#include #endif // defined(XGBOOST_USE_NVTX) -namespace xgboost { -namespace common { - +namespace xgboost::common { void Monitor::Start(std::string const &name) { if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) { auto &stats = statistics_map_[name]; @@ -61,9 +58,10 @@ void Monitor::Print() const { kv.second.timer.elapsed) .count()); } + if (stat_map.empty()) { + return; + } LOG(CONSOLE) << "======== Monitor (" << rank << "): " << label_ << " ========"; this->PrintStatistics(stat_map); } - -} // namespace common -} // namespace xgboost +} // namespace xgboost::common diff --git a/src/data/adapter.h b/src/data/adapter.h index e9a4ad9fc748..0ad1e9e3864c 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -73,11 +73,11 @@ constexpr size_t kAdapterUnknownSize = std::numeric_limits::max(); struct COOTuple { COOTuple() = default; - XGBOOST_DEVICE COOTuple(size_t row_idx, size_t column_idx, float value) + XGBOOST_DEVICE COOTuple(bst_idx_t row_idx, bst_idx_t column_idx, float value) : row_idx(row_idx), column_idx(column_idx), value(value) {} - size_t row_idx{0}; - size_t column_idx{0}; + bst_idx_t row_idx{0}; + bst_idx_t column_idx{0}; float value{0}; }; @@ -136,12 +136,8 @@ class CSRAdapterBatch : public detail::NoMetaInfo { public: class Line { public: - Line(size_t row_idx, size_t size, const unsigned* feature_idx, - const float* values) - : row_idx_(row_idx), - size_(size), - feature_idx_(feature_idx), - values_(values) {} + Line(bst_idx_t row_idx, bst_idx_t size, const unsigned* feature_idx, const float* values) + : row_idx_(row_idx), size_(size), feature_idx_(feature_idx), values_(values) {} size_t Size() const { return size_; } COOTuple GetElement(size_t idx) const { @@ -149,8 +145,8 @@ class CSRAdapterBatch : public detail::NoMetaInfo { } private: - size_t row_idx_; - size_t size_; + bst_idx_t row_idx_; + bst_idx_t size_; const unsigned* feature_idx_; const float* values_; }; @@ -178,29 +174,25 @@ class CSRAdapterBatch : public detail::NoMetaInfo { class CSRAdapter : public detail::SingleBatchDataIter { public: - CSRAdapter(const size_t* row_ptr, const unsigned* feature_idx, - const float* values, size_t num_rows, size_t num_elements, - size_t num_features) - : batch_(row_ptr, feature_idx, values, num_rows, num_elements, - num_features), + CSRAdapter(const size_t* row_ptr, const unsigned* feature_idx, const float* values, + bst_idx_t num_rows, bst_idx_t num_elements, size_t num_features) + : batch_(row_ptr, feature_idx, values, num_rows, num_elements, num_features), num_rows_(num_rows), num_columns_(num_features) {} const CSRAdapterBatch& Value() const override { return batch_; } - size_t NumRows() const { return num_rows_; } - size_t NumColumns() const { return num_columns_; } + bst_idx_t NumRows() const { return num_rows_; } + bst_idx_t NumColumns() const { return num_columns_; } private: CSRAdapterBatch batch_; - size_t num_rows_; - size_t num_columns_; + bst_idx_t num_rows_; + bst_idx_t num_columns_; }; class DenseAdapterBatch : public detail::NoMetaInfo { public: - DenseAdapterBatch(const float* values, size_t num_rows, size_t num_features) - : values_(values), - num_rows_(num_rows), - num_features_(num_features) {} + DenseAdapterBatch(const float* values, bst_idx_t num_rows, bst_idx_t num_features) + : values_(values), num_rows_(num_rows), num_features_(num_features) {} private: class Line { @@ -910,7 +902,7 @@ class SparsePageAdapterBatch { struct Line { Entry const* inst; size_t n; - bst_row_t ridx; + bst_idx_t ridx; COOTuple GetElement(size_t idx) const { return {ridx, inst[idx].index, inst[idx].fvalue}; } size_t Size() const { return n; } }; diff --git a/src/data/array_interface.h b/src/data/array_interface.h index d645c9e755d6..fafe0b6acc8e 100644 --- a/src/data/array_interface.h +++ b/src/data/array_interface.h @@ -615,7 +615,12 @@ auto DispatchDType(ArrayInterfaceHandler::Type dtype, Fn dispatch) { case ArrayInterfaceHandler::kF16: { using T = long double; CHECK(sizeof(T) == 16) << error::NoF128(); - return dispatch(T{}); + // Avoid invalid type. + if constexpr (sizeof(T) == 16) { + return dispatch(T{}); + } else { + return dispatch(double{}); + } } case ArrayInterfaceHandler::kI1: { return dispatch(std::int8_t{}); diff --git a/src/data/data.cc b/src/data/data.cc index 24b41640c173..f37a10fa30b8 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -11,7 +11,6 @@ #include // for abs #include // for uint64_t, int32_t, uint8_t, uint32_t #include // for size_t, strcmp, memcpy -#include // for exception #include // for operator<<, basic_ostream, basic_ostream::op... #include // for map, operator!= #include // for accumulate, partial_sum @@ -19,10 +18,10 @@ #include // for remove_pointer_t, remove_reference #include "../collective/communicator-inl.h" // for GetRank, GetWorldSize, Allreduce, IsFederated -#include "../collective/communicator.h" // for Operation +#include "../collective/allgather.h" +#include "../collective/allreduce.h" #include "../common/algorithm.h" // for StableSort #include "../common/api_entry.h" // for XGBAPIThreadLocalEntry -#include "../common/common.h" // for Split #include "../common/error_msg.h" // for GroupSize, GroupWeight, InfInData #include "../common/group_data.h" // for ParallelGroupBuilder #include "../common/io.h" // for PeekableInStream @@ -47,7 +46,7 @@ #include "simple_dmatrix.h" // for SimpleDMatrix #include "sparse_page_writer.h" // for SparsePageFormatReg #include "validation.h" // for LabelsCheck, WeightsCheck, ValidateQueryGroup -#include "xgboost/base.h" // for bst_group_t, bst_row_t, bst_float, bst_ulong +#include "xgboost/base.h" // for bst_group_t, bst_idx_t, bst_float, bst_ulong #include "xgboost/context.h" // for Context #include "xgboost/host_device_vector.h" // for HostDeviceVector #include "xgboost/learner.h" // for HostDeviceVector @@ -473,11 +472,11 @@ void MetaInfo::SetInfo(Context const& ctx, StringView key, StringView interface_ << ", must have at least 1 column even if it's empty."; auto const& first = get(array.front()); auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData(first); - is_cuda = ArrayInterfaceHandler::IsCudaPtr(ptr); + is_cuda = first.find("stream") != first.cend() || ArrayInterfaceHandler::IsCudaPtr(ptr); } else { auto const& first = get(j_interface); auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData(first); - is_cuda = ArrayInterfaceHandler::IsCudaPtr(ptr); + is_cuda = first.find("stream") != first.cend() || ArrayInterfaceHandler::IsCudaPtr(ptr); } if (is_cuda) { @@ -567,46 +566,6 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) { } } -void MetaInfo::SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype, - size_t num) { - CHECK(key); - auto proc = [&](auto cast_d_ptr) { - using T = std::remove_pointer_t; - auto t = linalg::TensorView(common::Span{cast_d_ptr, num}, {num}, DeviceOrd::CPU()); - CHECK(t.CContiguous()); - Json interface { - linalg::ArrayInterface(t) - }; - assert(ArrayInterface<1>{interface}.is_contiguous); - return interface; - }; - // Legacy code using XGBoost dtype, which is a small subset of array interface types. - switch (dtype) { - case xgboost::DataType::kFloat32: { - auto cast_ptr = reinterpret_cast(dptr); - this->SetInfoFromHost(ctx, key, proc(cast_ptr)); - break; - } - case xgboost::DataType::kDouble: { - auto cast_ptr = reinterpret_cast(dptr); - this->SetInfoFromHost(ctx, key, proc(cast_ptr)); - break; - } - case xgboost::DataType::kUInt32: { - auto cast_ptr = reinterpret_cast(dptr); - this->SetInfoFromHost(ctx, key, proc(cast_ptr)); - break; - } - case xgboost::DataType::kUInt64: { - auto cast_ptr = reinterpret_cast(dptr); - this->SetInfoFromHost(ctx, key, proc(cast_ptr)); - break; - } - default: - LOG(FATAL) << "Unknown data type" << static_cast(dtype); - } -} - void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype, const void** out_dptr) const { if (dtype == DataType::kFloat32) { @@ -643,41 +602,42 @@ void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype, } void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulong size) { - if (size != 0 && this->num_col_ != 0 && !IsColumnSplit()) { + bool is_col_split = this->IsColumnSplit(); + + if (size != 0 && this->num_col_ != 0 && !is_col_split) { CHECK_EQ(size, this->num_col_) << "Length of " << key << " must be equal to number of columns."; CHECK(info); } - if (!std::strcmp(key, "feature_type")) { - feature_type_names.clear(); - for (size_t i = 0; i < size; ++i) { - auto elem = info[i]; - feature_type_names.emplace_back(elem); - } - if (IsColumnSplit()) { - feature_type_names = collective::AllgatherStrings(feature_type_names); - CHECK_EQ(feature_type_names.size(), num_col_) + // Gather column info when data is split by columns + auto gather_columns = [is_col_split, key, n_columns = this->num_col_](auto const& inputs) { + if (is_col_split) { + std::remove_const_t> result; + auto rc = collective::AllgatherStrings(inputs, &result); + collective::SafeColl(rc); + CHECK_EQ(result.size(), n_columns) << "Length of " << key << " must be equal to number of columns."; + return result; } + return inputs; + }; + + if (StringView{key} == "feature_type") { // NOLINT + this->feature_type_names.clear(); + std::copy(info, info + size, std::back_inserter(feature_type_names)); + feature_type_names = gather_columns(feature_type_names); auto& h_feature_types = feature_types.HostVector(); this->has_categorical_ = LoadFeatureType(feature_type_names, &h_feature_types); - } else if (!std::strcmp(key, "feature_name")) { - if (IsColumnSplit()) { - std::vector local_feature_names{}; + } else if (StringView{key} == "feature_name") { // NOLINT + feature_names.clear(); + if (is_col_split) { auto const rank = collective::GetRank(); - for (std::size_t i = 0; i < size; ++i) { - auto elem = std::to_string(rank) + "." + info[i]; - local_feature_names.emplace_back(elem); - } - feature_names = collective::AllgatherStrings(local_feature_names); - CHECK_EQ(feature_names.size(), num_col_) - << "Length of " << key << " must be equal to number of columns."; + std::transform(info, info + size, std::back_inserter(feature_names), + [rank](char const* elem) { return std::to_string(rank) + "." + elem; }); } else { - feature_names.clear(); - for (size_t i = 0; i < size; ++i) { - feature_names.emplace_back(info[i]); - } + std::copy(info, info + size, std::back_inserter(feature_names)); } + feature_names = gather_columns(feature_names); } else { LOG(FATAL) << "Unknown feature info name: " << key; } @@ -770,12 +730,10 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col } } -void MetaInfo::SynchronizeNumberOfColumns(Context const*) { - if (IsColumnSplit()) { - collective::Allreduce(&num_col_, 1); - } else { - collective::Allreduce(&num_col_, 1); - } +void MetaInfo::SynchronizeNumberOfColumns(Context const* ctx) { + auto op = IsColumnSplit() ? collective::Op::kSum : collective::Op::kMax; + auto rc = collective::Allreduce(ctx, linalg::MakeVec(&num_col_, 1), op); + collective::SafeColl(rc); } namespace { @@ -996,7 +954,7 @@ template DMatrix* DMatrix::Create( SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const { SparsePage transpose; - common::ParallelGroupBuilder builder(&transpose.offset.HostVector(), + common::ParallelGroupBuilder builder(&transpose.offset.HostVector(), &transpose.data.HostVector()); builder.InitBudget(num_columns, n_threads); long batch_size = static_cast(this->Size()); // NOLINT(*) @@ -1192,7 +1150,7 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread void SparsePage::PushCSC(const SparsePage &batch) { std::vector& self_data = data.HostVector(); - std::vector& self_offset = offset.HostVector(); + std::vector& self_offset = offset.HostVector(); auto const& other_data = batch.data.ConstHostVector(); auto const& other_offset = batch.offset.ConstHostVector(); @@ -1211,7 +1169,7 @@ void SparsePage::PushCSC(const SparsePage &batch) { return; } - std::vector offset(other_offset.size()); + std::vector offset(other_offset.size()); offset[0] = 0; std::vector data(self_data.size() + other_data.size()); diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh index a5156f585441..bc012fd9b439 100644 --- a/src/data/device_adapter.cuh +++ b/src/data/device_adapter.cuh @@ -39,7 +39,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo { return {row_idx, column_idx, value}; } - [[nodiscard]] __device__ float GetElement(bst_row_t ridx, bst_feature_t fidx) const { + [[nodiscard]] __device__ float GetElement(bst_idx_t ridx, bst_feature_t fidx) const { auto const& column = columns_[fidx]; float value = column.valid.Data() == nullptr || column.valid.Check(ridx) ? column(ridx) @@ -47,8 +47,8 @@ class CudfAdapterBatch : public detail::NoMetaInfo { return value; } - [[nodiscard]] XGBOOST_DEVICE bst_row_t NumRows() const { return num_rows_; } - [[nodiscard]] XGBOOST_DEVICE bst_row_t NumCols() const { return columns_.size(); } + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return num_rows_; } + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return columns_.size(); } private: common::Span> columns_; @@ -168,13 +168,13 @@ class CupyAdapterBatch : public detail::NoMetaInfo { float value = array_interface_(row_idx, column_idx); return {row_idx, column_idx, value}; } - [[nodiscard]] __device__ float GetElement(bst_row_t ridx, bst_feature_t fidx) const { + [[nodiscard]] __device__ float GetElement(bst_idx_t ridx, bst_feature_t fidx) const { float value = array_interface_(ridx, fidx); return value; } - [[nodiscard]] XGBOOST_DEVICE bst_row_t NumRows() const { return array_interface_.Shape(0); } - [[nodiscard]] XGBOOST_DEVICE bst_row_t NumCols() const { return array_interface_.Shape(1); } + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return array_interface_.Shape(0); } + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return array_interface_.Shape(1); } private: ArrayInterface<2> array_interface_; @@ -208,8 +208,8 @@ class CupyAdapter : public detail::SingleBatchDataIter { // Returns maximum row length template -std::size_t GetRowCounts(const AdapterBatchT batch, common::Span offset, DeviceOrd device, - float missing) { +bst_idx_t GetRowCounts(const AdapterBatchT batch, common::Span offset, DeviceOrd device, + float missing) { dh::safe_cuda(cudaSetDevice(device.ordinal)); IsValidFunctor is_valid(missing); dh::safe_cuda(cudaMemsetAsync(offset.data(), '\0', offset.size_bytes())); @@ -231,7 +231,7 @@ std::size_t GetRowCounts(const AdapterBatchT batch, common::Span offs // Count elements per row dh::LaunchN(n_samples * stride, [=] __device__(std::size_t idx) { - bst_row_t cnt{0}; + bst_idx_t cnt{0}; auto [ridx, fbeg] = linalg::UnravelIndex(idx, n_samples, stride); SPAN_CHECK(ridx < n_samples); for (bst_feature_t fidx = fbeg; fidx < n_features; fidx += stride) { @@ -245,10 +245,10 @@ std::size_t GetRowCounts(const AdapterBatchT batch, common::Span offs static_cast(cnt)); // NOLINT }); dh::XGBCachingDeviceAllocator alloc; - bst_row_t row_stride = + bst_idx_t row_stride = dh::Reduce(thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()), thrust::device_pointer_cast(offset.data()) + offset.size(), - static_cast(0), thrust::maximum()); + static_cast(0), thrust::maximum()); return row_stride; } diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index c60fe83893bb..d9ea85919bd8 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -171,11 +171,10 @@ struct WriteCompressedEllpackFunctor { using Tuple = thrust::tuple; __device__ size_t operator()(Tuple out) { - auto e = batch.GetElement(out.get<2>()); + auto e = batch.GetElement(thrust::get<2>(out)); if (is_valid(e)) { // -1 because the scan is inclusive - size_t output_position = - accessor.row_stride * e.row_idx + out.get<1>() - 1; + size_t output_position = accessor.row_stride * e.row_idx + thrust::get<1>(out) - 1; uint32_t bin_idx = 0; if (common::IsCat(feature_types, e.column_idx)) { bin_idx = accessor.SearchBin(e.value, e.column_idx); @@ -192,8 +191,8 @@ template struct TupleScanOp { __device__ Tuple operator()(Tuple a, Tuple b) { // Key equal - if (a.template get<0>() == b.template get<0>()) { - b.template get<1>() += a.template get<1>(); + if (thrust::get<0>(a) == thrust::get<0>(b)) { + thrust::get<1>(b) += thrust::get<1>(a); return b; } // Not equal diff --git a/src/data/file_iterator.cc b/src/data/file_iterator.cc index cebfbdc19f65..1e341447c35a 100644 --- a/src/data/file_iterator.cc +++ b/src/data/file_iterator.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021-2023, XGBoost contributors + * Copyright 2021-2024, XGBoost contributors */ #include "file_iterator.h" @@ -10,7 +10,10 @@ #include // for operator<<, basic_ostream, istringstream #include // for vector -#include "../common/common.h" // for Split +#include "../common/common.h" // for Split +#include "xgboost/linalg.h" // for ArrayInterfaceStr, MakeVec +#include "xgboost/linalg.h" +#include "xgboost/logging.h" // for CHECK #include "xgboost/string_view.h" // for operator<<, StringView namespace xgboost::data { @@ -28,10 +31,10 @@ std::string ValidateFileFormat(std::string const& uri) { for (size_t i = 0; i < arg_list.size(); ++i) { std::istringstream is(arg_list[i]); std::pair kv; - CHECK(std::getline(is, kv.first, '=')) << "Invalid uri argument format" - << " for key in arg " << i + 1; - CHECK(std::getline(is, kv.second)) << "Invalid uri argument format" - << " for value in arg " << i + 1; + CHECK(std::getline(is, kv.first, '=')) + << "Invalid uri argument format" << " for key in arg " << i + 1; + CHECK(std::getline(is, kv.second)) + << "Invalid uri argument format" << " for value in arg " << i + 1; args.insert(kv); } if (args.find("format") == args.cend()) { @@ -48,4 +51,41 @@ std::string ValidateFileFormat(std::string const& uri) { return name_args[0] + "?" + name_args[1] + '#' + name_args_cache[1]; } } + +int FileIterator::Next() { + CHECK(parser_); + if (parser_->Next()) { + row_block_ = parser_->Value(); + + indptr_ = linalg::Make1dInterface(row_block_.offset, row_block_.size + 1); + values_ = linalg::Make1dInterface(row_block_.value, row_block_.offset[row_block_.size]); + indices_ = linalg::Make1dInterface(row_block_.index, row_block_.offset[row_block_.size]); + + size_t n_columns = + *std::max_element(row_block_.index, row_block_.index + row_block_.offset[row_block_.size]); + // dmlc parser converts 1-based indexing back to 0-based indexing so we can ignore + // this condition and just add 1 to n_columns + n_columns += 1; + + XGProxyDMatrixSetDataCSR(proxy_, indptr_.c_str(), indices_.c_str(), values_.c_str(), n_columns); + + if (row_block_.label) { + auto str = linalg::Make1dInterface(row_block_.label, row_block_.size); + XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str()); + } + if (row_block_.qid) { + auto str = linalg::Make1dInterface(row_block_.qid, row_block_.size); + XGDMatrixSetInfoFromInterface(proxy_, "qid", str.c_str()); + } + if (row_block_.weight) { + auto str = linalg::Make1dInterface(row_block_.weight, row_block_.size); + XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str()); + } + // Continue iteration + return true; + } else { + // Stop iteration + return false; + } +} } // namespace xgboost::data diff --git a/src/data/file_iterator.h b/src/data/file_iterator.h index c7f23b478879..a4afbabe4077 100644 --- a/src/data/file_iterator.h +++ b/src/data/file_iterator.h @@ -1,20 +1,16 @@ /** - * Copyright 2021-2023, XGBoost contributors + * Copyright 2021-2024, XGBoost contributors */ #ifndef XGBOOST_DATA_FILE_ITERATOR_H_ #define XGBOOST_DATA_FILE_ITERATOR_H_ -#include // for max_element -#include // for size_t #include // for uint32_t #include // for unique_ptr #include // for string #include // for move #include "dmlc/data.h" // for RowBlock, Parser -#include "xgboost/c_api.h" // for XGDMatrixSetDenseInfo, XGDMatrixFree, XGProxyDMatrixCreate -#include "xgboost/linalg.h" // for ArrayInterfaceStr, MakeVec -#include "xgboost/logging.h" // for CHECK +#include "xgboost/c_api.h" // for XGDMatrixFree, XGProxyDMatrixCreate namespace xgboost::data { [[nodiscard]] std::string ValidateFileFormat(std::string const& uri); @@ -53,41 +49,7 @@ class FileIterator { XGDMatrixFree(proxy_); } - int Next() { - CHECK(parser_); - if (parser_->Next()) { - row_block_ = parser_->Value(); - using linalg::MakeVec; - - indptr_ = ArrayInterfaceStr(MakeVec(row_block_.offset, row_block_.size + 1)); - values_ = ArrayInterfaceStr(MakeVec(row_block_.value, row_block_.offset[row_block_.size])); - indices_ = ArrayInterfaceStr(MakeVec(row_block_.index, row_block_.offset[row_block_.size])); - - size_t n_columns = *std::max_element(row_block_.index, - row_block_.index + row_block_.offset[row_block_.size]); - // dmlc parser converts 1-based indexing back to 0-based indexing so we can ignore - // this condition and just add 1 to n_columns - n_columns += 1; - - XGProxyDMatrixSetDataCSR(proxy_, indptr_.c_str(), indices_.c_str(), - values_.c_str(), n_columns); - - if (row_block_.label) { - XGDMatrixSetDenseInfo(proxy_, "label", row_block_.label, row_block_.size, 1); - } - if (row_block_.qid) { - XGDMatrixSetDenseInfo(proxy_, "qid", row_block_.qid, row_block_.size, 1); - } - if (row_block_.weight) { - XGDMatrixSetDenseInfo(proxy_, "weight", row_block_.weight, row_block_.size, 1); - } - // Continue iteration - return true; - } else { - // Stop iteration - return false; - } - } + int Next(); auto Proxy() -> decltype(proxy_) { return proxy_; } diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc index 88a38d5cce74..493aded70098 100644 --- a/src/data/gradient_index.cc +++ b/src/data/gradient_index.cc @@ -193,7 +193,7 @@ float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const { float GHistIndexMatrix::GetFvalue(std::vector const &ptrs, std::vector const &values, std::vector const &mins, - bst_row_t ridx, bst_feature_t fidx, bool is_cat) const { + bst_idx_t ridx, bst_feature_t fidx, bool is_cat) const { if (is_cat) { auto gidx = GetGindex(ridx, fidx); if (gidx == -1) { diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h index 0bb93fc20900..f1754fe35121 100644 --- a/src/data/gradient_index.h +++ b/src/data/gradient_index.h @@ -149,7 +149,7 @@ class GHistIndexMatrix { /** @brief max_bin for each feature. */ bst_bin_t max_numeric_bins_per_feat; /** @brief base row index for current page (used by external memory) */ - bst_row_t base_rowid{0}; + bst_idx_t base_rowid{0}; [[nodiscard]] bst_bin_t MaxNumBinPerFeat() const { return std::max(static_cast(cut.MaxCategory() + 1), max_numeric_bins_per_feat); @@ -230,7 +230,7 @@ class GHistIndexMatrix { */ [[nodiscard]] std::size_t RowIdx(size_t ridx) const { return row_ptr[ridx - base_rowid]; } - [[nodiscard]] bst_row_t Size() const { return row_ptr.empty() ? 0 : row_ptr.size() - 1; } + [[nodiscard]] bst_idx_t Size() const { return row_ptr.empty() ? 0 : row_ptr.size() - 1; } [[nodiscard]] bst_feature_t Features() const { return cut.Ptrs().size() - 1; } [[nodiscard]] bool ReadColumnPage(common::AlignedResourceReadStream* fi); @@ -243,7 +243,7 @@ class GHistIndexMatrix { [[nodiscard]] float GetFvalue(size_t ridx, size_t fidx, bool is_cat) const; [[nodiscard]] float GetFvalue(std::vector const& ptrs, std::vector const& values, std::vector const& mins, - bst_row_t ridx, bst_feature_t fidx, bool is_cat) const; + bst_idx_t ridx, bst_feature_t fidx, bool is_cat) const; [[nodiscard]] common::HistogramCuts& Cuts() { return cut; } [[nodiscard]] common::HistogramCuts const& Cuts() const { return cut; } diff --git a/src/data/iterative_dmatrix.cc b/src/data/iterative_dmatrix.cc index e5aa98278c8e..e581e90ca40b 100644 --- a/src/data/iterative_dmatrix.cc +++ b/src/data/iterative_dmatrix.cc @@ -9,11 +9,12 @@ #include // for underlying_type_t #include // for vector -#include "../collective/communicator-inl.h" -#include "../common/categorical.h" // common::IsCat +#include "../collective/allreduce.h" // for Allreduce +#include "../collective/communicator-inl.h" // for IsDistributed +#include "../common/categorical.h" // common::IsCat #include "../common/column_matrix.h" -#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter. -#include "batch_utils.h" // for RegenGHist +#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter. +#include "batch_utils.h" // for RegenGHist #include "gradient_index.h" #include "proxy_dmatrix.h" #include "simple_batch_iterator.h" @@ -95,13 +96,13 @@ void GetCutsFromRef(Context const* ctx, std::shared_ptr ref, bst_featur namespace { // Synchronize feature type in case of empty DMatrix -void SyncFeatureType(Context const*, std::vector* p_h_ft) { +void SyncFeatureType(Context const* ctx, std::vector* p_h_ft) { if (!collective::IsDistributed()) { return; } auto& h_ft = *p_h_ft; - auto n_ft = h_ft.size(); - collective::Allreduce(&n_ft, 1); + bst_idx_t n_ft = h_ft.size(); + collective::SafeColl(collective::Allreduce(ctx, &n_ft, collective::Op::kMax)); if (!h_ft.empty()) { // Check correct size if this is not an empty DMatrix. CHECK_EQ(h_ft.size(), n_ft); @@ -109,7 +110,8 @@ void SyncFeatureType(Context const*, std::vector* p_h_ft) { if (n_ft > 0) { h_ft.resize(n_ft); auto ptr = reinterpret_cast*>(h_ft.data()); - collective::Allreduce(ptr, h_ft.size()); + collective::SafeColl( + collective::Allreduce(ctx, linalg::MakeVec(ptr, h_ft.size()), collective::Op::kMax)); } } } // anonymous namespace @@ -132,7 +134,7 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p, return HostAdapterDispatch(proxy, [](auto const& value) { return value.NumCols(); }); }; - std::vector column_sizes; + std::vector column_sizes; auto const is_valid = data::IsValidFunctor{missing}; auto nnz_cnt = [&]() { return HostAdapterDispatch(proxy, [&](auto const& value) { @@ -175,7 +177,7 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p, // We use do while here as the first batch is fetched in ctor if (n_features == 0) { n_features = num_cols(); - collective::Allreduce(&n_features, 1); + collective::SafeColl(collective::Allreduce(ctx, &n_features, collective::Op::kMax)); column_sizes.clear(); column_sizes.resize(n_features, 0); info_.num_col_ = n_features; diff --git a/src/data/iterative_dmatrix.cu b/src/data/iterative_dmatrix.cu index 09a3976d785c..69a7b1aa2568 100644 --- a/src/data/iterative_dmatrix.cu +++ b/src/data/iterative_dmatrix.cu @@ -1,20 +1,18 @@ /** - * Copyright 2020-2023, XGBoost contributors + * Copyright 2020-2024, XGBoost contributors */ #include #include -#include +#include "../collective/allreduce.h" #include "../common/hist_util.cuh" #include "batch_utils.h" // for RegenGHist #include "device_adapter.cuh" #include "ellpack_page.cuh" -#include "gradient_index.h" #include "iterative_dmatrix.h" #include "proxy_dmatrix.cuh" #include "proxy_dmatrix.h" #include "simple_batch_iterator.h" -#include "sparse_page_source.h" namespace xgboost::data { void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p, @@ -63,7 +61,8 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p, dh::safe_cuda(cudaSetDevice(get_device().ordinal)); if (cols == 0) { cols = num_cols(); - collective::Allreduce(&cols, 1); + auto rc = collective::Allreduce(ctx, linalg::MakeVec(&cols, 1), collective::Op::kMax); + SafeColl(rc); this->info_.num_col_ = cols; } else { CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns."; diff --git a/src/data/proxy_dmatrix.h b/src/data/proxy_dmatrix.h index 7efff7af4bb1..a29fde8423fa 100644 --- a/src/data/proxy_dmatrix.h +++ b/src/data/proxy_dmatrix.h @@ -171,12 +171,13 @@ decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_ } else { LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name(); } - if constexpr (get_value) { - return std::invoke_result_t< - Fn, decltype(std::declval>()->Value())>(); - } else { - return std::invoke_result_t>())>(); - } + } + + if constexpr (get_value) { + return std::invoke_result_t>()->Value())>(); + } else { + return std::invoke_result_t>())>(); } } diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 99bf67ba0c86..f54d1c43eda4 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -1,5 +1,5 @@ /** - * Copyright 2014~2023, XGBoost Contributors + * Copyright 2014-2024, XGBoost Contributors * \file simple_dmatrix.cc * \brief the input data structure for gradient boosting * \author Tianqi Chen @@ -13,6 +13,7 @@ #include #include "../collective/communicator-inl.h" // for GetWorldSize, GetRank, Allgather +#include "../collective/allgather.h" #include "../common/error_msg.h" // for InconsistentMaxBin #include "./simple_batch_iterator.h" #include "adapter.h" @@ -59,7 +60,7 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) { auto& h_data = out_page.data.HostVector(); auto& h_offset = out_page.offset.HostVector(); size_t rptr{0}; - for (bst_row_t i = 0; i < this->Info().num_row_; i++) { + for (bst_idx_t i = 0; i < this->Info().num_row_; i++) { auto inst = batch[i]; auto prev_size = h_data.size(); std::copy_if(inst.begin(), inst.end(), std::back_inserter(h_data), @@ -76,8 +77,11 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) { void SimpleDMatrix::ReindexFeatures(Context const* ctx) { if (info_.IsColumnSplit() && collective::GetWorldSize() > 1) { - auto const cols = collective::Allgather(info_.num_col_); - auto const offset = std::accumulate(cols.cbegin(), cols.cbegin() + collective::GetRank(), 0ul); + std::vector buffer(collective::GetWorldSize()); + buffer[collective::GetRank()] = info_.num_col_; + auto rc = collective::Allgather(ctx, linalg::MakeVec(buffer.data(), buffer.size())); + SafeColl(rc); + auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0); if (offset == 0) { return; } diff --git a/src/data/simple_dmatrix.cuh b/src/data/simple_dmatrix.cuh index 528bea8be80a..e3c241886007 100644 --- a/src/data/simple_dmatrix.cuh +++ b/src/data/simple_dmatrix.cuh @@ -40,7 +40,7 @@ void CopyDataToDMatrix(AdapterBatchT batch, common::Span data, } template -void CountRowOffsets(const AdapterBatchT& batch, common::Span offset, DeviceOrd device, +void CountRowOffsets(const AdapterBatchT& batch, common::Span offset, DeviceOrd device, float missing) { dh::safe_cuda(cudaSetDevice(device.ordinal)); IsValidFunctor is_valid(missing); diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 9cb0e364fb6b..00aeeb5427f0 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -1,5 +1,5 @@ /** - * Copyright 2014-2023, XGBoost Contributors + * Copyright 2014-2024, XGBoost Contributors * \file sparse_page_source.h */ #ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_ @@ -7,23 +7,28 @@ #include // for min #include // for atomic +#include // for remove #include // for async -#include -#include -#include // for mutex -#include -#include -#include // for pair, move -#include - -#include "../common/common.h" -#include "../common/io.h" // for PrivateMmapConstStream -#include "../common/timer.h" // for Monitor, Timer -#include "adapter.h" -#include "proxy_dmatrix.h" // for DMatrixProxy -#include "sparse_page_writer.h" // for SparsePageFormat -#include "xgboost/base.h" -#include "xgboost/data.h" +#include // for unique_ptr +#include // for mutex +#include // for partial_sum +#include // for string +#include // for pair, move +#include // for vector + +#if !defined(XGBOOST_USE_CUDA) +#include "../common/common.h" // for AssertGPUSupport +#endif // !defined(XGBOOST_USE_CUDA) + +#include "../common/io.h" // for PrivateMmapConstStream +#include "../common/threadpool.h" // for ThreadPool +#include "../common/timer.h" // for Monitor, Timer +#include "proxy_dmatrix.h" // for DMatrixProxy +#include "sparse_page_writer.h" // for SparsePageFormat +#include "xgboost/base.h" // for bst_feature_t +#include "xgboost/data.h" // for SparsePage, CSCPage +#include "xgboost/global_config.h" // for GlobalConfigThreadLocalStore +#include "xgboost/logging.h" // for CHECK_EQ namespace xgboost::data { inline void TryDeleteCacheFile(const std::string& file) { @@ -145,6 +150,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl { std::mutex single_threaded_; // The current page. std::shared_ptr page_; + // Workers for fetching data from external memory. + common::ThreadPool workers_; bool at_end_ {false}; float missing_; @@ -158,8 +165,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl { std::shared_ptr cache_info_; using Ring = std::vector>>; - // A ring storing futures to data. Since the DMatrix iterator is forward only, so we - // can pre-fetch data in a ring. + // A ring storing futures to data. Since the DMatrix iterator is forward only, we can + // pre-fetch data in a ring. std::unique_ptr ring_{new Ring}; // Catching exception in pre-fetch threads to prevent segfault. Not always work though, // OOM error can be delayed due to lazy commit. On the bright side, if mmap is used then @@ -177,14 +184,18 @@ class SparsePageSourceImpl : public BatchIteratorImpl { } // An heuristic for number of pre-fetched batches. We can make it part of BatchParam // to let user adjust number of pre-fetched batches when needed. - std::int32_t n_prefetches = std::max(nthreads_, 3); + std::int32_t kPrefetches = 3; + std::int32_t n_prefetches = std::min(nthreads_, kPrefetches); + n_prefetches = std::max(n_prefetches, 1); std::int32_t n_prefetch_batches = std::min(static_cast(n_prefetches), n_batches_); CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_; + CHECK_LE(n_prefetch_batches, kPrefetches); std::size_t fetch_it = count_; exce_.Rethrow(); + auto const config = *GlobalConfigThreadLocalStore::Get(); for (std::int32_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) { fetch_it %= n_batches_; // ring if (ring_->at(fetch_it).valid()) { @@ -192,7 +203,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl { } auto const* self = this; // make sure it's const CHECK_LT(fetch_it, cache_info_->offset.size()); - ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self, this]() { + ring_->at(fetch_it) = this->workers_.Submit([fetch_it, self, config, this] { + *GlobalConfigThreadLocalStore::Get() = config; auto page = std::make_shared(); this->exce_.Run([&] { std::unique_ptr> fmt{CreatePageFormat("raw")}; @@ -247,7 +259,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl { public: SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches, std::shared_ptr cache) - : missing_{missing}, + : workers_{nthreads}, + missing_{missing}, nthreads_{nthreads}, n_features_{n_features}, n_batches_{n_batches}, diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index a2d84d8485a3..d6ed851c835c 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -1,5 +1,5 @@ /** - * Copyright 2014-2023 by Contributors + * Copyright 2014-2024, XGBoost Contributors * \file gbtree.cc * \brief gradient boosted tree implementation. * \author Tianqi Chen @@ -11,14 +11,12 @@ #include #include // std::int32_t -#include #include +#include // for iota #include -#include #include #include -#include "../common/common.h" #include "../common/timer.h" #include "../tree/param.h" // TrainParam #include "gbtree_model.h" diff --git a/src/learner.cc b/src/learner.cc index eed9dd5cdcd7..93db7f801407 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -18,7 +18,6 @@ #include // for int32_t, uint32_t, int64_t, uint64_t #include // for atoi #include // for memcpy, size_t, memset -#include // for less #include // for operator<<, setiosflags #include // for back_insert_iterator, distance, back_inserter #include // for numeric_limits @@ -36,7 +35,6 @@ #include "collective/aggregator.h" // for ApplyWithLabels #include "collective/communicator-inl.h" // for Allreduce, Broadcast, GetRank, IsDistributed -#include "collective/communicator.h" // for Operation #include "common/api_entry.h" // for XGBAPIThreadLocalEntry #include "common/charconv.h" // for to_chars, to_chars_result, NumericLimits, from_... #include "common/common.h" // for ToString, Split @@ -209,7 +207,7 @@ struct LearnerModelParamLegacy : public dmlc::Parameter return dmlc::Parameter::UpdateAllowUnknown(kwargs); } // sanity check - void Validate(Context const*) { + void Validate(Context const* ctx) { if (!collective::IsDistributed()) { return; } @@ -230,7 +228,8 @@ struct LearnerModelParamLegacy : public dmlc::Parameter std::array sync; std::copy(data.cbegin(), data.cend(), sync.begin()); - collective::Broadcast(sync.data(), sync.size(), 0); + auto rc = collective::Broadcast(ctx, linalg::MakeVec(sync.data(), sync.size()), 0); + collective::SafeColl(rc); CHECK(std::equal(data.cbegin(), data.cend(), sync.cbegin())) << "Different model parameter across workers."; } @@ -755,7 +754,9 @@ class LearnerConfiguration : public Learner { num_feature = std::max(num_feature, static_cast(num_col)); } - collective::Allreduce(&num_feature, 1); + auto rc = + collective::Allreduce(&ctx_, linalg::MakeVec(&num_feature, 1), collective::Op::kMax); + collective::SafeColl(rc); if (num_feature > mparam_.num_feature) { mparam_.num_feature = num_feature; } diff --git a/src/logging.cc b/src/logging.cc index d24c6633d987..4cf74207d2d9 100644 --- a/src/logging.cc +++ b/src/logging.cc @@ -1,14 +1,13 @@ -/*! - * Copyright 2015-2018 by Contributors +/** + * Copyright 2015-2024, XGBoost Contributors * \file logging.cc * \brief Implementation of loggers. * \author Tianqi Chen */ -#include - -#include "xgboost/parameter.h" #include "xgboost/logging.h" +#include // for string + #include "collective/communicator-inl.h" #if !defined(XGBOOST_STRICT_R_MODE) || XGBOOST_STRICT_R_MODE == 0 diff --git a/src/metric/auc.cc b/src/metric/auc.cc index 212a3a027d35..6de0d1f129cb 100644 --- a/src/metric/auc.cc +++ b/src/metric/auc.cc @@ -177,7 +177,7 @@ double GroupRankingROC(Context const* ctx, common::Span predts, if (sum_w != 0) { auc /= sum_w; } - CHECK_LE(auc, 1.0f); + CHECK_LE(auc, 1.0 + kRtEps); return auc; } @@ -264,9 +264,14 @@ class EvalAUC : public MetricNoCache { info.weights_.SetDevice(ctx_->Device()); } // We use the global size to handle empty dataset. - std::array meta{info.labels.Size(), preds.Size()}; + std::array meta{info.labels.Size(), preds.Size()}; if (!info.IsVerticalFederated()) { - collective::Allreduce(meta.data(), meta.size()); + auto rc = collective::Allreduce( + ctx_, + linalg::MakeTensorView(DeviceOrd::CPU(), common::Span{meta.data(), meta.size()}, + meta.size()), + collective::Op::kMax); + collective::SafeColl(rc); } if (meta[0] == 0) { // Empty across all workers, which is not supported. @@ -290,8 +295,8 @@ class EvalAUC : public MetricNoCache { auc = collective::GlobalRatio(ctx_, info, auc, static_cast(valid_groups)); if (!std::isnan(auc)) { - CHECK_LE(auc, 1) << "Total AUC across groups: " << auc * valid_groups - << ", valid groups: " << valid_groups; + CHECK_LE(auc, 1.0 + kRtEps) << "Total AUC across groups: " << auc * valid_groups + << ", valid groups: " << valid_groups; } } else if (meta[0] != meta[1] && meta[1] % meta[0] == 0) { /** @@ -311,7 +316,8 @@ class EvalAUC : public MetricNoCache { } auc = collective::GlobalRatio(ctx_, info, auc, fp * tp); if (!std::isnan(auc)) { - CHECK_LE(auc, 1.0); + CHECK_LE(auc, 1.0 + kRtEps); + auc = std::min(auc, 1.0); } } if (std::isnan(auc)) { diff --git a/src/metric/auc.cu b/src/metric/auc.cu index 4ce10d094a50..59199b092839 100644 --- a/src/metric/auc.cu +++ b/src/metric/auc.cu @@ -1,9 +1,9 @@ /** - * Copyright 2021-2023 by XGBoost Contributors + * Copyright 2021-2024, XGBoost Contributors */ +#include // for copy #include -#include #include #include // NOLINT #include @@ -11,7 +11,7 @@ #include #include -#include "../collective/communicator-inl.cuh" +#include "../collective/allreduce.h" #include "../common/algorithm.cuh" // SegmentedArgSort #include "../common/optional_weight.h" // OptionalWeights #include "../common/threading_utils.cuh" // UnravelTrapeziodIdx,SegmentedTrapezoidThreads @@ -201,13 +201,16 @@ void Transpose(common::Span in, common::Span out, size_t m, }); } -double ScaleClasses(Context const *ctx, common::Span results, +double ScaleClasses(Context const *ctx, bool is_column_split, common::Span results, common::Span local_area, common::Span tp, common::Span auc, size_t n_classes) { - if (collective::IsDistributed()) { - int32_t device = dh::CurrentDevice(); + // With vertical federated learning, only the root has label, other parties are not + // evaluation metrics. + if (collective::IsDistributed() && !(is_column_split && collective::IsFederated())) { + std::int32_t device = dh::CurrentDevice(); CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device); - collective::AllReduce(device, results.data(), results.size()); + auto rc = collective::Allreduce( + ctx, linalg::MakeVec(results.data(), results.size(), ctx->Device()), collective::Op::kSum); } auto reduce_in = dh::MakeTransformIterator( thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) { @@ -334,7 +337,7 @@ double GPUMultiClassAUCOVR(Context const *ctx, MetaInfo const &info, auto local_area = d_results.subspan(0, n_classes); auto tp = d_results.subspan(2 * n_classes, n_classes); auto auc = d_results.subspan(3 * n_classes, n_classes); - return ScaleClasses(ctx, d_results, local_area, tp, auc, n_classes); + return ScaleClasses(ctx, info.IsColumnSplit(), d_results, local_area, tp, auc, n_classes); } /** @@ -438,7 +441,7 @@ double GPUMultiClassAUCOVR(Context const *ctx, MetaInfo const &info, tp[c] = 1.0f; } }); - return ScaleClasses(ctx, d_results, local_area, tp, auc, n_classes); + return ScaleClasses(ctx, info.IsColumnSplit(), d_results, local_area, tp, auc, n_classes); } void MultiClassSortedIdx(Context const *ctx, common::Span predts, @@ -835,7 +838,7 @@ std::pair GPURankingPRAUC(Context const *ctx, InitCacheOnce(predts, p_cache); dh::device_vector group_ptr(info.group_ptr_.size()); - thrust::copy(info.group_ptr_.begin(), info.group_ptr_.end(), group_ptr.begin()); + thrust::copy(info.group_ptr_.begin(), info.group_ptr_.end(), group_ptr.begin()); // NOLINT auto d_group_ptr = dh::ToSpan(group_ptr); CHECK_GE(info.group_ptr_.size(), 1) << "Must have at least 1 query group for LTR."; size_t n_groups = info.group_ptr_.size() - 1; diff --git a/src/metric/auc.h b/src/metric/auc.h index 4fe2ecec4dd9..f27a1dda6160 100644 --- a/src/metric/auc.h +++ b/src/metric/auc.h @@ -1,18 +1,14 @@ /** - * Copyright 2021-2023, XGBoost Contributors + * Copyright 2021-2024, XGBoost Contributors */ #ifndef XGBOOST_METRIC_AUC_H_ #define XGBOOST_METRIC_AUC_H_ -#include #include -#include #include #include #include #include "../collective/communicator-inl.h" -#include "../common/common.h" -#include "../common/threading_utils.h" #include "xgboost/base.h" #include "xgboost/data.h" #include "xgboost/metric.h" diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index 9c26011aa99f..ec5b9079d7d9 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -10,15 +10,15 @@ #include #include +#include // for accumulate -#include "../collective/communicator-inl.h" -#include "../common/common.h" // MetricNoCache +#include "../common/common.h" // for AssertGPUSupport #include "../common/math.h" #include "../common/optional_weight.h" // OptionalWeights #include "../common/pseudo_huber.h" #include "../common/quantile_loss_utils.h" // QuantileLossParam #include "../common/threading_utils.h" -#include "metric_common.h" +#include "metric_common.h" // MetricNoCache #include "xgboost/collective/result.h" // for SafeColl #include "xgboost/metric.h" diff --git a/src/metric/metric_common.h b/src/metric/metric_common.h index 53c38ff2a8c2..2b9239990f8a 100644 --- a/src/metric/metric_common.h +++ b/src/metric/metric_common.h @@ -9,8 +9,6 @@ #include #include "../collective/aggregator.h" -#include "../collective/communicator-inl.h" -#include "../common/common.h" #include "xgboost/metric.h" namespace xgboost { diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu index acaef7cf7e84..e51509fc7339 100644 --- a/src/metric/multiclass_metric.cu +++ b/src/metric/multiclass_metric.cu @@ -9,8 +9,8 @@ #include #include #include +#include // for accumulate -#include "../collective/communicator-inl.h" #include "../common/math.h" #include "../common/threading_utils.h" #include "metric_common.h" // MetricNoCache diff --git a/src/metric/survival_metric.cu b/src/metric/survival_metric.cu index c64fece6c1d3..9c57be3ab2b5 100644 --- a/src/metric/survival_metric.cu +++ b/src/metric/survival_metric.cu @@ -9,10 +9,9 @@ #include #include +#include // for accumulate #include -#include "../collective/communicator-inl.h" -#include "../common/math.h" #include "../common/survival_util.h" #include "../common/threading_utils.h" #include "metric_common.h" // MetricNoCache diff --git a/src/objective/adaptive.h b/src/objective/adaptive.h index cbe69e79a6cc..1a7aef0516d1 100644 --- a/src/objective/adaptive.h +++ b/src/objective/adaptive.h @@ -9,8 +9,6 @@ #include // std::vector #include "../collective/aggregator.h" -#include "../collective/communicator-inl.h" -#include "../common/common.h" #include "xgboost/base.h" // bst_node_t #include "xgboost/context.h" // Context #include "xgboost/data.h" // MetaInfo @@ -42,7 +40,7 @@ inline void UpdateLeafValues(Context const* ctx, std::vector* p_quantiles auto& quantiles = *p_quantiles; auto const& h_node_idx = nidx; - size_t n_leaf = collective::GlobalMax(ctx, info, h_node_idx.size()); + bst_idx_t n_leaf = collective::GlobalMax(ctx, info, static_cast(h_node_idx.size())); CHECK(quantiles.empty() || quantiles.size() == n_leaf); if (quantiles.empty()) { quantiles.resize(n_leaf, std::numeric_limits::quiet_NaN()); diff --git a/src/objective/lambdarank_obj.cc b/src/objective/lambdarank_obj.cc index 0c9d1262a204..36495d0caa88 100644 --- a/src/objective/lambdarank_obj.cc +++ b/src/objective/lambdarank_obj.cc @@ -222,7 +222,7 @@ class LambdaRankObj : public FitIntercept { }; MakePairs(ctx_, iter, p_cache_, g, g_label, g_rank, loop); - if (sum_lambda > 0.0) { + if (sum_lambda > 0.0 && param_.lambdarank_normalization) { double norm = std::log2(1.0 + sum_lambda) / sum_lambda; std::transform(g_gpair.Values().data(), g_gpair.Values().data() + g_gpair.Size(), g_gpair.Values().data(), [norm](GradientPair const& g) { return g * norm; }); @@ -474,7 +474,6 @@ class LambdaRankMAP : public LambdaRankObj { public: void GetGradientImpl(std::int32_t iter, const HostDeviceVector& predt, const MetaInfo& info, linalg::Matrix* out_gpair) { - CHECK(param_.ndcg_exp_gain) << "NDCG gain can not be set for the MAP objective."; if (ctx_->IsCUDA()) { return cuda_impl::LambdaRankGetGradientMAP( ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->Device()), @@ -564,7 +563,6 @@ class LambdaRankPairwise : public LambdaRankObj& predt, const MetaInfo& info, linalg::Matrix* out_gpair) { - CHECK(param_.ndcg_exp_gain) << "NDCG gain can not be set for the pairwise objective."; if (ctx_->IsCUDA()) { return cuda_impl::LambdaRankGetGradientPairwise( ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->Device()), @@ -610,6 +608,13 @@ class LambdaRankPairwise : public LambdaRankObjRankEvalMetric("ndcg"); } + + [[nodiscard]] Json DefaultMetricConfig() const override { + Json config{Object{}}; + config["name"] = String{DefaultEvalMetric()}; + config["lambdarank_param"] = ToJson(param_); + return config; + } }; #if !defined(XGBOOST_USE_CUDA) diff --git a/src/objective/lambdarank_obj.cu b/src/objective/lambdarank_obj.cu index 30eba2fdcf2e..25c5d138ccc7 100644 --- a/src/objective/lambdarank_obj.cu +++ b/src/objective/lambdarank_obj.cu @@ -266,12 +266,13 @@ void CalcGrad(Context const* ctx, MetaInfo const& info, std::shared_ptrWeightNorm(); + auto norm = p_cache->Param().lambdarank_normalization; thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), d_gpair.Size(), [=] XGBOOST_DEVICE(std::size_t i) mutable { auto g = dh::SegmentId(d_gptr, i); auto sum_lambda = thrust::get<2>(d_max_lambdas[g]); // Normalization - if (sum_lambda > 0.0) { + if (sum_lambda > 0.0 && norm) { double norm = std::log2(1.0 + sum_lambda) / sum_lambda; d_gpair(i, 0) *= norm; } diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 06b8079ee134..32992e3134ad 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -1,5 +1,5 @@ /** - * Copyright 2017-2023 by XGBoost Contributors + * Copyright 2017-2024, XGBoost Contributors */ #include // for max, fill, min #include // for any, any_cast @@ -12,7 +12,7 @@ #include // for vector #include "../collective/communicator-inl.h" // for Allreduce, IsDistributed -#include "../collective/communicator.h" // for Operation +#include "../collective/allreduce.h" #include "../common/bitfield.h" // for RBitField8 #include "../common/categorical.h" // for IsCat, Decision #include "../common/common.h" // for DivRoundUp @@ -184,7 +184,7 @@ void FVecDrop(std::size_t const block_size, std::size_t const fvec_offset, static std::size_t constexpr kUnroll = 8; struct SparsePageView { - bst_row_t base_rowid; + bst_idx_t base_rowid; HostSparsePageView view; explicit SparsePageView(SparsePage const *p) : base_rowid{p->base_rowid} { view = p->GetView(); } @@ -193,7 +193,7 @@ struct SparsePageView { }; struct SingleInstanceView { - bst_row_t base_rowid{}; + bst_idx_t base_rowid{}; SparsePage::Inst const &inst; explicit SingleInstanceView(SparsePage::Inst const &instance) : inst{instance} {} @@ -214,7 +214,7 @@ struct GHistIndexMatrixView { std::vector const& values_; public: - size_t base_rowid; + bst_idx_t base_rowid; public: GHistIndexMatrixView(GHistIndexMatrix const &_page, uint64_t n_feat, @@ -292,7 +292,7 @@ class AdapterView { [[nodiscard]] size_t Size() const { return adapter_->NumRows(); } - bst_row_t const static base_rowid = 0; // NOLINT + bst_idx_t const static base_rowid = 0; // NOLINT }; template @@ -461,11 +461,17 @@ class ColumnSplitHelper { return tree_offsets_[tree_index] * n_rows_ + row_id * tree_sizes_[tree_index] + node_id; } - void AllreduceBitVectors(Context const*) { - collective::Allreduce(decision_storage_.data(), - decision_storage_.size()); - collective::Allreduce(missing_storage_.data(), - missing_storage_.size()); + void AllreduceBitVectors(Context const *ctx) { + auto rc = collective::Success() << [&] { + return collective::Allreduce( + ctx, linalg::MakeVec(decision_storage_.data(), decision_storage_.size()), + collective::Op::kBitwiseOR); + } << [&] { + return collective::Allreduce( + ctx, linalg::MakeVec(missing_storage_.data(), missing_storage_.size()), + collective::Op::kBitwiseAND); + }; + collective::SafeColl(rc); } void MaskOneTree(RegTree::FVec const &feat, std::size_t tree_id, std::size_t row_id) { diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 7dcb5b5fc0f8..29fb6bb6a162 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -1,5 +1,5 @@ /** - * Copyright 2017-2023 by XGBoost Contributors + * Copyright 2017-2024, XGBoost Contributors */ #include #include @@ -11,7 +11,7 @@ #include // for any, any_cast #include -#include "../collective/communicator-inl.cuh" +#include "../collective/allreduce.h" #include "../common/bitfield.h" #include "../common/categorical.h" #include "../common/common.h" @@ -67,12 +67,12 @@ struct TreeView { struct SparsePageView { common::Span d_data; - common::Span d_row_ptr; + common::Span d_row_ptr; bst_feature_t num_features; SparsePageView() = default; XGBOOST_DEVICE SparsePageView(common::Span data, - common::Span row_ptr, + common::Span row_ptr, bst_feature_t num_features) : d_data{data}, d_row_ptr{row_ptr}, num_features(num_features) {} [[nodiscard]] __device__ float GetElement(size_t ridx, size_t fidx) const { @@ -113,7 +113,7 @@ struct SparsePageLoader { float* smem; __device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features, - bst_row_t num_rows, size_t entry_start, float) + bst_idx_t num_rows, size_t entry_start, float) : use_shared(use_shared), data(data) { extern __shared__ float _smem[]; @@ -146,7 +146,7 @@ struct SparsePageLoader { struct EllpackLoader { EllpackDeviceAccessor const& matrix; - XGBOOST_DEVICE EllpackLoader(EllpackDeviceAccessor const& m, bool, bst_feature_t, bst_row_t, + XGBOOST_DEVICE EllpackLoader(EllpackDeviceAccessor const& m, bool, bst_feature_t, bst_idx_t, size_t, float) : matrix{m} {} [[nodiscard]] __device__ __forceinline__ float GetElement(size_t ridx, size_t fidx) const { @@ -177,7 +177,7 @@ struct DeviceAdapterLoader { using BatchT = Batch; XGBOOST_DEV_INLINE DeviceAdapterLoader(Batch const batch, bool use_shared, - bst_feature_t num_features, bst_row_t num_rows, + bst_feature_t num_features, bst_idx_t num_rows, size_t entry_start, float missing) : batch{batch}, columns{num_features}, use_shared{use_shared}, is_valid{missing} { extern __shared__ float _smem[]; @@ -215,7 +215,7 @@ struct DeviceAdapterLoader { }; template -__device__ bst_node_t GetLeafIndex(bst_row_t ridx, TreeView const &tree, +__device__ bst_node_t GetLeafIndex(bst_idx_t ridx, TreeView const &tree, Loader *loader) { bst_node_t nidx = 0; RegTree::Node n = tree.d_tree[nidx]; @@ -230,7 +230,7 @@ __device__ bst_node_t GetLeafIndex(bst_row_t ridx, TreeView const &tree, } template -__device__ float GetLeafWeight(bst_row_t ridx, TreeView const &tree, +__device__ float GetLeafWeight(bst_idx_t ridx, TreeView const &tree, Loader *loader) { bst_node_t nidx = -1; if (tree.HasCategoricalSplit()) { @@ -255,7 +255,7 @@ PredictLeafKernel(Data data, common::Span d_nodes, size_t tree_begin, size_t tree_end, size_t num_features, size_t num_rows, size_t entry_start, bool use_shared, float missing) { - bst_row_t ridx = blockDim.x * blockIdx.x + threadIdx.x; + bst_idx_t ridx = blockDim.x * blockIdx.x + threadIdx.x; if (ridx >= num_rows) { return; } @@ -664,7 +664,7 @@ __global__ void MaskBitVectorKernel( } } -__device__ bst_node_t GetLeafIndexByBitVector(bst_row_t ridx, TreeView const& tree, +__device__ bst_node_t GetLeafIndexByBitVector(bst_idx_t ridx, TreeView const& tree, BitVector const& decision_bits, BitVector const& missing_bits, std::size_t num_nodes, std::size_t tree_offset) { @@ -682,7 +682,7 @@ __device__ bst_node_t GetLeafIndexByBitVector(bst_row_t ridx, TreeView const& tr return nidx; } -__device__ float GetLeafWeightByBitVector(bst_row_t ridx, TreeView const& tree, +__device__ float GetLeafWeightByBitVector(bst_idx_t ridx, TreeView const& tree, BitVector const& decision_bits, BitVector const& missing_bits, std::size_t num_nodes, std::size_t tree_offset) { @@ -817,10 +817,18 @@ class ColumnSplitHelper { void AllReduceBitVectors(dh::caching_device_vector* decision_storage, dh::caching_device_vector* missing_storage) const { - collective::AllReduce( - ctx_->Ordinal(), decision_storage->data().get(), decision_storage->size()); - collective::AllReduce( - ctx_->Ordinal(), missing_storage->data().get(), missing_storage->size()); + auto rc = collective::Success() << [&] { + return collective::Allreduce( + ctx_, + linalg::MakeVec(decision_storage->data().get(), decision_storage->size(), ctx_->Device()), + collective::Op::kBitwiseOR); + } << [&] { + return collective::Allreduce( + ctx_, + linalg::MakeVec(missing_storage->data().get(), missing_storage->size(), ctx_->Device()), + collective::Op::kBitwiseAND); + }; + collective::SafeColl(rc); } void ResizeBitVectors(dh::caching_device_vector* decision_storage, @@ -1171,7 +1179,7 @@ class GPUPredictor : public xgboost::Predictor { auto max_shared_memory_bytes = ConfigureDevice(ctx_->Device()); const MetaInfo& info = p_fmat->Info(); - bst_row_t num_rows = info.num_row_; + bst_idx_t num_rows = info.num_row_; if (tree_end == 0 || tree_end > model.trees.size()) { tree_end = static_cast(model.trees.size()); } @@ -1196,7 +1204,7 @@ class GPUPredictor : public xgboost::Predictor { for (auto const& batch : p_fmat->GetBatches()) { batch.data.SetDevice(ctx_->Device()); batch.offset.SetDevice(ctx_->Device()); - bst_row_t batch_offset = 0; + bst_idx_t batch_offset = 0; SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan(), model.learner_model_param->num_feature}; size_t num_rows = batch.Size(); @@ -1219,7 +1227,7 @@ class GPUPredictor : public xgboost::Predictor { } } else { for (auto const& batch : p_fmat->GetBatches(ctx_, BatchParam{})) { - bst_row_t batch_offset = 0; + bst_idx_t batch_offset = 0; EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(ctx_->Device())}; size_t num_rows = batch.Size(); auto grid = diff --git a/src/predictor/predictor.cc b/src/predictor/predictor.cc index 019804eda31c..2a6d1b9c58db 100644 --- a/src/predictor/predictor.cc +++ b/src/predictor/predictor.cc @@ -9,7 +9,7 @@ #include // for string, to_string #include "../gbm/gbtree_model.h" // for GBTreeModel -#include "xgboost/base.h" // for bst_float, Args, bst_group_t, bst_row_t +#include "xgboost/base.h" // for bst_float, Args, bst_group_t, bst_idx_t #include "xgboost/context.h" // for Context #include "xgboost/data.h" // for MetaInfo #include "xgboost/host_device_vector.h" // for HostDeviceVector @@ -34,7 +34,7 @@ Predictor* Predictor::Create(std::string const& name, Context const* ctx) { } template -void ValidateBaseMarginShape(linalg::Tensor const& margin, bst_row_t n_samples, +void ValidateBaseMarginShape(linalg::Tensor const& margin, bst_idx_t n_samples, bst_group_t n_groups) { // FIXME: Bindings other than Python doesn't have shape. std::string expected{"Invalid shape of base_margin. Expected: (" + std::to_string(n_samples) + diff --git a/src/tree/common_row_partitioner.h b/src/tree/common_row_partitioner.h index 4360c0b1314e..c3065ad5f135 100644 --- a/src/tree/common_row_partitioner.h +++ b/src/tree/common_row_partitioner.h @@ -1,24 +1,28 @@ /** - * Copyright 2021-2023 XGBoost contributors + * Copyright 2021-2023, XGBoost contributors * \file common_row_partitioner.h * \brief Common partitioner logic for hist and approx methods. */ #ifndef XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_ #define XGBOOST_TREE_COMMON_ROW_PARTITIONER_H_ -#include // std::all_of -#include // std::uint32_t -#include // std::numeric_limits -#include - -#include "../collective/communicator-inl.h" -#include "../common/linalg_op.h" // cbegin -#include "../common/numeric.h" // Iota -#include "../common/partition_builder.h" -#include "hist/expand_entry.h" // CPUExpandEntry -#include "xgboost/base.h" -#include "xgboost/context.h" // Context -#include "xgboost/linalg.h" // TensorView +#include // for all_of, fill +#include // for uint32_t +#include // for numeric_limits +#include // for vector + +#include "../collective/allreduce.h" // for Allreduce +#include "../common/bitfield.h" // for RBitField8 +#include "../common/linalg_op.h" // for cbegin +#include "../common/numeric.h" // for Iota +#include "../common/partition_builder.h" // for PartitionBuilder +#include "../common/row_set.h" // for RowSetCollection +#include "../common/threading_utils.h" // for ParallelFor2d +#include "xgboost/base.h" // for bst_row_t +#include "xgboost/collective/result.h" // for Success, SafeColl +#include "xgboost/context.h" // for Context +#include "xgboost/linalg.h" // for TensorView +#include "xgboost/span.h" // for Span namespace xgboost::tree { @@ -28,7 +32,7 @@ class ColumnSplitHelper { public: ColumnSplitHelper() = default; - ColumnSplitHelper(bst_row_t num_row, + ColumnSplitHelper(bst_idx_t num_row, common::PartitionBuilder* partition_builder, common::RowSetCollection* row_set_collection) : partition_builder_{partition_builder}, row_set_collection_{row_set_collection} { @@ -39,7 +43,7 @@ class ColumnSplitHelper { } template - void Partition(common::BlockedSpace2d const& space, std::int32_t n_threads, + void Partition(Context const* ctx, common::BlockedSpace2d const& space, std::int32_t n_threads, GHistIndexMatrix const& gmat, common::ColumnMatrix const& column_matrix, std::vector const& nodes, std::vector const& split_conditions, RegTree const* p_tree) { @@ -56,10 +60,12 @@ class ColumnSplitHelper { }); // Then aggregate the bit vectors across all the workers. - collective::Allreduce(decision_storage_.data(), - decision_storage_.size()); - collective::Allreduce(missing_storage_.data(), - missing_storage_.size()); + auto rc = collective::Success() << [&] { + return collective::Allreduce(ctx, &decision_storage_, collective::Op::kBitwiseOR); + } << [&] { + return collective::Allreduce(ctx, &missing_storage_, collective::Op::kBitwiseAND); + }; + collective::SafeColl(rc); // Finally use the bit vectors to partition the rows. common::ParallelFor2d(space, n_threads, [&](size_t node_in_set, common::Range1d r) { @@ -85,10 +91,10 @@ class ColumnSplitHelper { class CommonRowPartitioner { public: - bst_row_t base_rowid = 0; + bst_idx_t base_rowid = 0; CommonRowPartitioner() = default; - CommonRowPartitioner(Context const* ctx, bst_row_t num_row, bst_row_t _base_rowid, + CommonRowPartitioner(Context const* ctx, bst_idx_t num_row, bst_idx_t _base_rowid, bool is_col_split) : base_rowid{_base_rowid}, is_col_split_{is_col_split} { row_set_collection_.Clear(); @@ -220,7 +226,7 @@ class CommonRowPartitioner { // Store results in intermediate buffers from partition_builder_ if (is_col_split_) { column_split_helper_.Partition( - space, ctx->Threads(), gmat, column_matrix, nodes, split_conditions, p_tree); + ctx, space, ctx->Threads(), gmat, column_matrix, nodes, split_conditions, p_tree); } else { common::ParallelFor2d(space, ctx->Threads(), [&](size_t node_in_set, common::Range1d r) { size_t begin = r.begin(); diff --git a/src/tree/fit_stump.cu b/src/tree/fit_stump.cu index 832d40754ec9..dd71465df1cc 100644 --- a/src/tree/fit_stump.cu +++ b/src/tree/fit_stump.cu @@ -47,8 +47,10 @@ void FitStump(Context const* ctx, MetaInfo const& info, thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it, thrust::make_discard_iterator(), dh::tbegin(d_sum.Values())); - collective::GlobalSum(info, ctx->Device(), reinterpret_cast(d_sum.Values().data()), - d_sum.Size() * 2); + auto rc = collective::GlobalSum(ctx, info, + linalg::MakeVec(reinterpret_cast(d_sum.Values().data()), + d_sum.Size() * 2, ctx->Device())); + SafeColl(rc); thrust::for_each_n(policy, thrust::make_counting_iterator(0ul), n_targets, [=] XGBOOST_DEVICE(std::size_t i) mutable { diff --git a/src/tree/gpu_hist/evaluate_splits.cu b/src/tree/gpu_hist/evaluate_splits.cu index ceb322c28616..5e225a13f142 100644 --- a/src/tree/gpu_hist/evaluate_splits.cu +++ b/src/tree/gpu_hist/evaluate_splits.cu @@ -1,11 +1,11 @@ /** - * Copyright 2020-2023, XGBoost Contributors + * Copyright 2020-2024, XGBoost Contributors */ #include // std::max #include #include -#include "../../collective/communicator-inl.cuh" +#include "../../collective/allgather.h" #include "../../common/categorical.h" #include "../../data/ellpack_page.cuh" #include "evaluate_splits.cuh" @@ -413,8 +413,14 @@ void GPUHistEvaluator::EvaluateSplits(Context const *ctx, const std::vector all_candidate_storage(out_splits.size() * world_size); auto all_candidates = dh::ToSpan(all_candidate_storage); - collective::AllGather(device_.ordinal, out_splits.data(), all_candidates.data(), - out_splits.size() * sizeof(DeviceSplitCandidate)); + auto current_rank = + all_candidates.subspan(collective::GetRank() * out_splits.size(), out_splits.size()); + dh::safe_cuda(cudaMemcpyAsync(current_rank.data(), out_splits.data(), + out_splits.size() * sizeof(DeviceSplitCandidate), + cudaMemcpyDeviceToDevice)); + auto rc = collective::Allgather( + ctx, linalg::MakeVec(all_candidates.data(), all_candidates.size(), ctx->Device())); + collective::SafeColl(rc); // Reduce to get the best candidate from all workers. dh::LaunchN(out_splits.size(), ctx->CUDACtx()->Stream(), diff --git a/src/tree/gpu_hist/gradient_based_sampler.cu b/src/tree/gpu_hist/gradient_based_sampler.cu index 58add0a9354f..f9a3819ad6e0 100644 --- a/src/tree/gpu_hist/gradient_based_sampler.cu +++ b/src/tree/gpu_hist/gradient_based_sampler.cu @@ -1,13 +1,13 @@ /** - * Copyright 2019-2023 by XGBoost Contributors + * Copyright 2019-2024, XGBoost Contributors */ #include #include +#include // for sort #include #include #include -#include #include // for size_t #include #include @@ -277,7 +277,7 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(Context const* c common::Span gpair, DMatrix* dmat) { auto cuctx = ctx->CUDACtx(); - bst_row_t n_rows = dmat->Info().num_row_; + bst_idx_t n_rows = dmat->Info().num_row_; size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex( gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_); diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index edd52ba22fc3..654c3c6627f3 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -12,6 +12,7 @@ #include // for move #include // for vector +#include "../../collective/allgather.h" #include "../../common/categorical.h" // for CatBitField #include "../../common/hist_util.h" // for GHistRow, HistogramCuts #include "../../common/linalg_op.h" // for cbegin, cend, begin @@ -35,7 +36,7 @@ template std::enable_if_t || std::is_same_v, std::vector> -AllgatherColumnSplit(std::vector const &entries) { +AllgatherColumnSplit(Context const *ctx, std::vector const &entries) { auto const n_entries = entries.size(); // First, gather all the primitive fields. @@ -52,7 +53,7 @@ AllgatherColumnSplit(std::vector const &entries) { serialized_entries.emplace_back(std::move(out)); } - auto all_serialized = collective::VectorAllgatherV(serialized_entries); + auto all_serialized = collective::VectorAllgatherV(ctx, serialized_entries); CHECK_GE(all_serialized.size(), local_entries.size()); std::vector all_entries(all_serialized.size()); @@ -422,7 +423,7 @@ class HistEvaluator { // Note that under secure vertical setting, only the label owner is able to evaluate the split // based on the global histogram. The other parties will receive the final best splits // allgather is capable of performing this (0-gain entries for non-label owners), - auto all_entries = AllgatherColumnSplit(entries); + auto all_entries = AllgatherColumnSplit(ctx_, entries); for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) { for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) { entries[nidx_in_set].split.Update( @@ -665,7 +666,7 @@ class HistMultiEvaluator { if (is_col_split_) { // With column-wise data split, we gather the best splits from all the workers and update the // expand entries accordingly. - auto all_entries = AllgatherColumnSplit(entries); + auto all_entries = AllgatherColumnSplit(ctx_, entries); for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) { for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) { entries[nidx_in_set].split.Update( diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h index d4cea58d0f72..b5fe3b4dc9fd 100644 --- a/src/tree/hist/histogram.h +++ b/src/tree/hist/histogram.h @@ -1,5 +1,5 @@ /** - * Copyright 2021-2023 by XGBoost Contributors + * Copyright 2021-2024, XGBoost Contributors */ #ifndef XGBOOST_TREE_HIST_HISTOGRAM_H_ #define XGBOOST_TREE_HIST_HISTOGRAM_H_ @@ -7,26 +7,24 @@ #include // for max #include // for size_t #include // for int32_t -#include // for function #include // for move #include // for vector -#include "../../collective/communicator-inl.h" // for Allreduce -#include "../../collective/communicator.h" // for Operation -#include "../../common/hist_util.h" // for GHistRow, ParallelGHi... -#include "../../common/row_set.h" // for RowSetCollection -#include "../../common/threading_utils.h" // for ParallelFor2d, Range1d, BlockedSpace2d -#include "../../data/gradient_index.h" // for GHistIndexMatrix -#include "expand_entry.h" // for MultiExpandEntry, CPUExpandEntry -#include "hist_cache.h" // for BoundedHistCollection -#include "param.h" // for HistMakerTrainParam -#include "xgboost/base.h" // for bst_node_t, bst_target_t, bst_bin_t -#include "xgboost/context.h" // for Context -#include "xgboost/data.h" // for BatchIterator, BatchSet -#include "xgboost/linalg.h" // for MatrixView, All, Vect... -#include "xgboost/logging.h" // for CHECK_GE -#include "xgboost/span.h" // for Span -#include "xgboost/tree_model.h" // for RegTree +#include "../../collective/allreduce.h" // for Allreduce +#include "../../common/hist_util.h" // for GHistRow, ParallelGHi... +#include "../../common/row_set.h" // for RowSetCollection +#include "../../common/threading_utils.h" // for ParallelFor2d, Range1d, BlockedSpace2d +#include "../../data/gradient_index.h" // for GHistIndexMatrix +#include "expand_entry.h" // for MultiExpandEntry, CPUExpandEntry +#include "hist_cache.h" // for BoundedHistCollection +#include "param.h" // for HistMakerTrainParam +#include "xgboost/base.h" // for bst_node_t, bst_target_t, bst_bin_t +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for BatchIterator, BatchSet +#include "xgboost/linalg.h" // for MatrixView, All, Vect... +#include "xgboost/logging.h" // for CHECK_GE +#include "xgboost/span.h" // for Span +#include "xgboost/tree_model.h" // for RegTree namespace xgboost::tree { /** @@ -173,7 +171,7 @@ class HistogramBuilder { } } - void SyncHistogram(Context const *, RegTree const *p_tree, + void SyncHistogram(Context const *ctx, RegTree const *p_tree, std::vector const &nodes_to_build, std::vector const &nodes_to_trick) { auto n_total_bins = buffer_.TotalBins(); @@ -189,8 +187,10 @@ class HistogramBuilder { CHECK(!nodes_to_build.empty()); auto first_nidx = nodes_to_build.front(); std::size_t n = n_total_bins * nodes_to_build.size() * 2; - collective::Allreduce( - reinterpret_cast(this->hist_[first_nidx].data()), n); + auto rc = collective::Allreduce( + ctx, linalg::MakeVec(reinterpret_cast(this->hist_[first_nidx].data()), n), + collective::Op::kSum); + SafeColl(rc); } if (is_distributed_ && is_col_split_ && is_secure_) { @@ -201,8 +201,9 @@ class HistogramBuilder { // AllReduce is the most efficient way of achieving the global histogram auto first_nidx = nodes_to_build.front(); std::size_t n = n_total_bins * nodes_to_build.size() * 2; - collective::Allreduce( - reinterpret_cast(this->hist_[first_nidx].data()), n); + collective::SafeColl(collective::Allreduce( + ctx, linalg::MakeVec(reinterpret_cast(this->hist_[first_nidx].data()), n), + collective::Op::kSum)); } common::BlockedSpace2d const &subspace = diff --git a/src/tree/hist/param.cc b/src/tree/hist/param.cc index bd8d7a85c407..10895d5111b8 100644 --- a/src/tree/hist/param.cc +++ b/src/tree/hist/param.cc @@ -1,18 +1,22 @@ /** - * Copyright 2021-2023, XGBoost Contributors + * Copyright 2021-2024, XGBoost Contributors */ #include "param.h" +#include // for binary #include // for string -#include "../../collective/communicator-inl.h" // for GetRank, Broadcast +#include "../../collective/broadcast.h" // for Broadcast +#include "../../collective/communicator-inl.h" // for GetRank #include "xgboost/json.h" // for Object, Json +#include "xgboost/linalg.h" // for MakeVec #include "xgboost/tree_model.h" // for RegTree namespace xgboost::tree { DMLC_REGISTER_PARAMETER(HistMakerTrainParam); -void HistMakerTrainParam::CheckTreesSynchronized(Context const*, RegTree const* local_tree) const { +void HistMakerTrainParam::CheckTreesSynchronized(Context const* ctx, + RegTree const* local_tree) const { if (!this->debug_synchronize) { return; } @@ -24,7 +28,15 @@ void HistMakerTrainParam::CheckTreesSynchronized(Context const*, RegTree const* local_tree->SaveModel(&model); } Json::Dump(model, &s_model, std::ios::binary); - collective::Broadcast(&s_model, 0); + + auto nchars{static_cast(s_model.size())}; + auto rc = collective::Success() << [&] { + return collective::Broadcast(ctx, linalg::MakeVec(&nchars, 1), 0); + } << [&] { + s_model.resize(nchars); + return collective::Broadcast(ctx, linalg::MakeVec(s_model.data(), s_model.size()), 0); + }; + collective::SafeColl(rc); RegTree ref_tree{}; // rank 0 tree auto j_ref_tree = Json::Load(StringView{s_model}, std::ios::binary); diff --git a/src/tree/hist/sampler.h b/src/tree/hist/sampler.h index 803e40d547bf..11b4ac1c6f16 100644 --- a/src/tree/hist/sampler.h +++ b/src/tree/hist/sampler.h @@ -54,7 +54,7 @@ inline void SampleGradient(Context const* ctx, TrainParam param, if (param.subsample >= 1.0) { return; } - bst_row_t n_samples = out.Shape(0); + bst_idx_t n_samples = out.Shape(0); auto& rnd = common::GlobalRandom(); #if XGBOOST_CUSTOMIZE_GLOBAL_PRNG diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index f18b519264a0..45834cc7755e 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -1,5 +1,5 @@ /** - * Copyright 2015-2023, XGBoost Contributors + * Copyright 2015-2024, XGBoost Contributors * \file tree_model.cc * \brief model structure for tree */ @@ -8,6 +8,7 @@ #include #include +#include // for array #include #include #include @@ -15,7 +16,7 @@ #include #include "../common/categorical.h" -#include "../common/common.h" // for EscapeU8 +#include "../common/common.h" // for EscapeU8 #include "../predictor/predict_fn.h" #include "io_utils.h" // for GetElem #include "param.h" @@ -31,26 +32,50 @@ namespace tree { DMLC_REGISTER_PARAMETER(TrainParam); } +namespace { +template +std::enable_if_t, std::string> ToStr(Float value) { + int32_t constexpr kFloatMaxPrecision = std::numeric_limits::max_digits10; + static_assert(std::is_floating_point::value, + "Use std::to_string instead for non-floating point values."); + std::stringstream ss; + ss << std::setprecision(kFloatMaxPrecision) << value; + return ss.str(); +} + +template +std::string ToStr(linalg::VectorView value, bst_target_t limit) { + int32_t constexpr kFloatMaxPrecision = std::numeric_limits::max_digits10; + static_assert(std::is_floating_point::value, + "Use std::to_string instead for non-floating point values."); + std::stringstream ss; + ss << std::setprecision(kFloatMaxPrecision); + if (value.Size() == 1) { + ss << value(0); + return ss.str(); + } + CHECK_GE(limit, 2); + auto n = std::min(static_cast(value.Size() - 1), limit - 1); + ss << "["; + for (std::size_t i = 0; i < n; ++i) { + ss << value(i) << ", "; + } + if (value.Size() > limit) { + ss << "..., "; + } + ss << value(value.Size() - 1) << "]"; + return ss.str(); +} +} // namespace /*! * \brief Base class for dump model implementation, modeling closely after code generator. */ class TreeGenerator { protected: - static int32_t constexpr kFloatMaxPrecision = - std::numeric_limits::max_digits10; FeatureMap const& fmap_; std::stringstream ss_; bool const with_stats_; - template - static std::string ToStr(Float value) { - static_assert(std::is_floating_point::value, - "Use std::to_string instead for non-floating point values."); - std::stringstream ss; - ss << std::setprecision(kFloatMaxPrecision) << value; - return ss.str(); - } - static std::string Tabs(uint32_t n) { std::string res; for (uint32_t i = 0; i < n; ++i) { @@ -258,10 +283,10 @@ class TextGenerator : public TreeGenerator { kLeafTemplate, {{"{tabs}", SuperT::Tabs(depth)}, {"{nid}", std::to_string(nid)}, - {"{leaf}", SuperT::ToStr(tree[nid].LeafValue())}, + {"{leaf}", ToStr(tree[nid].LeafValue())}, {"{stats}", with_stats_ ? SuperT::Match(kStatTemplate, - {{"{cover}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}}); + {{"{cover}", ToStr(tree.Stat(nid).sum_hess)}}) : ""}}); return result; } @@ -311,14 +336,14 @@ class TextGenerator : public TreeGenerator { static std::string const kQuantitiveTemplate = "{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}"; auto cond = tree[nid].SplitCond(); - return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth); + return SplitNodeImpl(tree, nid, kQuantitiveTemplate, ToStr(cond), depth); } std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override { auto cond = tree[nid].SplitCond(); static std::string const kNodeTemplate = "{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}"; - return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth); + return SplitNodeImpl(tree, nid, kNodeTemplate, ToStr(cond), depth); } std::string Categorical(RegTree const &tree, int32_t nid, @@ -336,8 +361,8 @@ class TextGenerator : public TreeGenerator { static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}"; std::string const result = SuperT::Match( kStatTemplate, - {{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)}, - {"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}); + {{"{loss_chg}", ToStr(tree.Stat(nid).loss_chg)}, + {"{sum_hess}", ToStr(tree.Stat(nid).sum_hess)}}); return result; } @@ -393,11 +418,11 @@ class JsonGenerator : public TreeGenerator { std::string result = SuperT::Match( kLeafTemplate, {{"{nid}", std::to_string(nid)}, - {"{leaf}", SuperT::ToStr(tree[nid].LeafValue())}, + {"{leaf}", ToStr(tree[nid].LeafValue())}, {"{stat}", with_stats_ ? SuperT::Match( kStatTemplate, {{"{sum_hess}", - SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}}); + ToStr(tree.Stat(nid).sum_hess)}}) : ""}}); return result; } @@ -468,7 +493,7 @@ class JsonGenerator : public TreeGenerator { R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I" R"I("missing": {missing})I"; bst_float cond = tree[nid].SplitCond(); - return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth); + return SplitNodeImpl(tree, nid, kQuantitiveTemplate, ToStr(cond), depth); } std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override { @@ -477,7 +502,7 @@ class JsonGenerator : public TreeGenerator { R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I" R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I" R"I("missing": {missing})I"; - return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth); + return SplitNodeImpl(tree, nid, kNodeTemplate, ToStr(cond), depth); } std::string NodeStat(RegTree const& tree, int32_t nid) const override { @@ -485,8 +510,8 @@ class JsonGenerator : public TreeGenerator { R"S(, "gain": {loss_chg}, "cover": {sum_hess})S"; auto result = SuperT::Match( kStatTemplate, - {{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)}, - {"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}); + {{"{loss_chg}", ToStr(tree.Stat(nid).loss_chg)}, + {"{sum_hess}", ToStr(tree.Stat(nid).sum_hess)}}); return result; } @@ -622,11 +647,11 @@ class GraphvizGenerator : public TreeGenerator { protected: template - std::string BuildEdge(RegTree const &tree, bst_node_t nid, int32_t child, bool left) const { + std::string BuildEdge(RegTree const &tree, bst_node_t nidx, int32_t child, bool left) const { static std::string const kEdgeTemplate = " {nid} -> {child} [label=\"{branch}\" color=\"{color}\"]\n"; // Is this the default child for missing value? - bool is_missing = tree[nid].DefaultChild() == child; + bool is_missing = tree.DefaultChild(nidx) == child; std::string branch; if (is_categorical) { branch = std::string{left ? "no" : "yes"} + std::string{is_missing ? ", missing" : ""}; @@ -635,7 +660,7 @@ class GraphvizGenerator : public TreeGenerator { } std::string buffer = SuperT::Match(kEdgeTemplate, - {{"{nid}", std::to_string(nid)}, + {{"{nid}", std::to_string(nidx)}, {"{child}", std::to_string(child)}, {"{color}", is_missing ? param_.yes_color : param_.no_color}, {"{branch}", branch}}); @@ -644,68 +669,77 @@ class GraphvizGenerator : public TreeGenerator { // Only indicator is different, so we combine all different node types into this // function. - std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t) const override { - auto split_index = tree[nid].SplitIndex(); - auto cond = tree[nid].SplitCond(); + std::string PlainNode(RegTree const& tree, bst_node_t nidx, uint32_t) const override { + auto split_index = tree.SplitIndex(nidx); + auto cond = tree.SplitCond(nidx); static std::string const kNodeTemplate = " {nid} [ label=\"{fname}{<}{cond}\" {params}]\n"; bool has_less = (split_index >= fmap_.Size()) || fmap_.TypeOf(split_index) != FeatureMap::kIndicator; std::string result = - SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nid)}, + SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nidx)}, {"{fname}", GetFeatureName(fmap_, split_index)}, {"{<}", has_less ? "<" : ""}, - {"{cond}", has_less ? SuperT::ToStr(cond) : ""}, + {"{cond}", has_less ? ToStr(cond) : ""}, {"{params}", param_.condition_node_params}}); - result += BuildEdge(tree, nid, tree[nid].LeftChild(), true); - result += BuildEdge(tree, nid, tree[nid].RightChild(), false); + result += BuildEdge(tree, nidx, tree.LeftChild(nidx), true); + result += BuildEdge(tree, nidx, tree.RightChild(nidx), false); return result; }; - std::string Categorical(RegTree const& tree, int32_t nid, uint32_t) const override { + std::string Categorical(RegTree const& tree, bst_node_t nidx, uint32_t) const override { static std::string const kLabelTemplate = " {nid} [ label=\"{fname}:{cond}\" {params}]\n"; - auto cats = GetSplitCategories(tree, nid); + auto cats = GetSplitCategories(tree, nidx); auto cats_str = PrintCatsAsSet(cats); - auto split_index = tree[nid].SplitIndex(); + auto split_index = tree.SplitIndex(nidx); std::string result = - SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nid)}, + SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nidx)}, {"{fname}", GetFeatureName(fmap_, split_index)}, {"{cond}", cats_str}, {"{params}", param_.condition_node_params}}); - result += BuildEdge(tree, nid, tree[nid].LeftChild(), true); - result += BuildEdge(tree, nid, tree[nid].RightChild(), false); + result += BuildEdge(tree, nidx, tree.LeftChild(nidx), true); + result += BuildEdge(tree, nidx, tree.RightChild(nidx), false); return result; } - std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t) const override { - static std::string const kLeafTemplate = - " {nid} [ label=\"leaf={leaf-value}\" {params}]\n"; - auto result = SuperT::Match(kLeafTemplate, { - {"{nid}", std::to_string(nid)}, - {"{leaf-value}", ToStr(tree[nid].LeafValue())}, - {"{params}", param_.leaf_node_params}}); - return result; - }; + std::string LeafNode(RegTree const& tree, bst_node_t nidx, uint32_t) const override { + static std::string const kLeafTemplate = " {nid} [ label=\"leaf={leaf-value}\" {params}]\n"; + // hardcoded limit to avoid dumping long arrays into dot graph. + bst_target_t constexpr kLimit{3}; + if (tree.IsMultiTarget()) { + auto value = tree.GetMultiTargetTree()->LeafValue(nidx); + auto result = SuperT::Match(kLeafTemplate, {{"{nid}", std::to_string(nidx)}, + {"{leaf-value}", ToStr(value, kLimit)}, + {"{params}", param_.leaf_node_params}}); + return result; + } else { + auto value = tree[nidx].LeafValue(); + auto result = SuperT::Match(kLeafTemplate, {{"{nid}", std::to_string(nidx)}, + {"{leaf-value}", ToStr(value)}, + {"{params}", param_.leaf_node_params}}); + return result; + } + } - std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override { - if (tree[nid].IsLeaf()) { - return this->LeafNode(tree, nid, depth); + std::string BuildTree(RegTree const& tree, bst_node_t nidx, uint32_t depth) override { + if (tree.IsLeaf(nidx)) { + return this->LeafNode(tree, nidx, depth); } static std::string const kNodeTemplate = "{parent}\n{left}\n{right}"; - auto node = tree.GetSplitTypes()[nid] == FeatureType::kCategorical - ? this->Categorical(tree, nid, depth) - : this->PlainNode(tree, nid, depth); + auto node = tree.GetSplitTypes()[nidx] == FeatureType::kCategorical + ? this->Categorical(tree, nidx, depth) + : this->PlainNode(tree, nidx, depth); auto result = SuperT::Match( kNodeTemplate, {{"{parent}", node}, - {"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)}, - {"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}}); + {"{left}", this->BuildTree(tree, tree.LeftChild(nidx), depth+1)}, + {"{right}", this->BuildTree(tree, tree.RightChild(nidx), depth+1)}}); return result; } @@ -733,7 +767,9 @@ XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot") constexpr bst_node_t RegTree::kRoot; std::string RegTree::DumpModel(const FeatureMap& fmap, bool with_stats, std::string format) const { - CHECK(!IsMultiTarget()); + if (this->IsMultiTarget() && format != "dot") { + LOG(FATAL) << format << " tree dump " << MTNotImplemented(); + } std::unique_ptr builder{TreeGenerator::Create(format, fmap, with_stats)}; builder->BuildTree(*this); diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index ef166fae5132..45018da17adc 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -106,6 +106,9 @@ class ColMaker: public TreeUpdater { if (dmat->Info().HasCategorical()) { LOG(FATAL) << error::NoCategorical("Updater `grow_colmaker` or `exact` tree method"); } + if (param->colsample_bynode - 1.0 != 0.0) { + LOG(FATAL) << "column sample by node is not yet supported by the exact tree method"; + } this->LazyGetColumnDensity(dmat); // rescale learning rate according to size of trees interaction_constraints_.Configure(*param, dmat->Info().num_row_); @@ -440,9 +443,8 @@ class ColMaker: public TreeUpdater { } // update the solution candidate - virtual void UpdateSolution(const SortedCSCPage &batch, - const std::vector &feat_set, - const std::vector &gpair, DMatrix *) { + void UpdateSolution(SortedCSCPage const &batch, const std::vector &feat_set, + const std::vector &gpair) { // start enumeration const auto num_features = feat_set.size(); CHECK(this->ctx_); @@ -466,17 +468,15 @@ class ColMaker: public TreeUpdater { } }); } + // find splits at current level, do split per level - inline void FindSplit(int depth, - const std::vector &qexpand, - const std::vector &gpair, - DMatrix *p_fmat, - RegTree *p_tree) { + void FindSplit(bst_node_t depth, const std::vector &qexpand, + std::vector const &gpair, DMatrix *p_fmat, RegTree *p_tree) { auto evaluator = tree_evaluator_.GetEvaluator(); auto feat_set = column_sampler_->GetFeatureSet(depth); for (const auto &batch : p_fmat->GetBatches(ctx_)) { - this->UpdateSolution(batch, feat_set->HostVector(), gpair, p_fmat); + this->UpdateSolution(batch, feat_set->HostVector(), gpair); } // after this each thread's stemp will get the best candidates, aggregate results this->SyncBestSolution(qexpand); diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 4911cec093c8..958fa0331569 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -13,7 +13,7 @@ #include #include "../collective/aggregator.h" -#include "../collective/aggregator.cuh" +#include "../collective/broadcast.h" #include "../common/bitfield.h" #include "../common/categorical.h" #include "../common/cuda_context.cuh" // CUDAContext @@ -191,7 +191,7 @@ struct GPUHistMakerDevice { std::unique_ptr feature_groups; GPUHistMakerDevice(Context const* ctx, bool is_external_memory, - common::Span _feature_types, bst_row_t _n_rows, + common::Span _feature_types, bst_idx_t _n_rows, TrainParam _param, std::shared_ptr column_sampler, uint32_t n_features, BatchParam batch_param, MetaInfo const& info) : evaluator_{_param, n_features, ctx->Device()}, @@ -410,11 +410,16 @@ struct GPUHistMakerDevice { } }); - collective::AllReduce( - ctx_->Ordinal(), decision_storage.data().get(), decision_storage.size()); - collective::AllReduce( - ctx_->Ordinal(), missing_storage.data().get(), missing_storage.size()); - collective::Synchronize(ctx_->Ordinal()); + auto rc = collective::Success() << [&] { + return collective::Allreduce( + ctx_, linalg::MakeTensorView(ctx_, dh::ToSpan(decision_storage), decision_storage.size()), + collective::Op::kBitwiseOR); + } << [&] { + return collective::Allreduce( + ctx_, linalg::MakeTensorView(ctx_, dh::ToSpan(missing_storage), missing_storage.size()), + collective::Op::kBitwiseAND); + }; + collective::SafeColl(rc); row_partitioner->UpdatePositionBatch( nidx, left_nidx, right_nidx, split_data, @@ -611,8 +616,11 @@ struct GPUHistMakerDevice { monitor.Start("AllReduce"); auto d_node_hist = hist.GetNodeHistogram(nidx).data(); using ReduceT = typename std::remove_pointer::type::ValueT; - collective::GlobalSum(info_, ctx_->Device(), reinterpret_cast(d_node_hist), - page->Cuts().TotalBins() * 2 * num_histograms); + auto rc = collective::GlobalSum( + ctx_, info_, + linalg::MakeVec(reinterpret_cast(d_node_hist), + page->Cuts().TotalBins() * 2 * num_histograms, ctx_->Device())); + SafeColl(rc); monitor.Stop("AllReduce"); } @@ -860,7 +868,9 @@ class GPUHistMaker : public TreeUpdater { // Synchronise the column sampling seed uint32_t column_sampling_seed = common::GlobalRandom()(); - collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0); + auto rc = collective::Broadcast( + ctx_, linalg::MakeVec(&column_sampling_seed, sizeof(column_sampling_seed)), 0); + SafeColl(rc); this->column_sampler_ = std::make_shared(column_sampling_seed); auto batch_param = BatchParam{param->max_bin, TrainParam::DftSparseThreshold()}; @@ -1001,9 +1011,7 @@ class GPUGlobalApproxMaker : public TreeUpdater { monitor_.Start(__func__); CHECK(ctx_->IsCUDA()) << error::InvalidCUDAOrdinal(); - // Synchronise the column sampling seed uint32_t column_sampling_seed = common::GlobalRandom()(); - collective::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0); this->column_sampler_ = std::make_shared(column_sampling_seed); p_last_fmat_ = p_fmat; diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index 941df7aec491..23c8ec9e6ea2 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -1,5 +1,5 @@ /** - * Copyright 2014-2023 by XGBoost Contributors + * Copyright 2014-2024, XGBoost Contributors * \file updater_refresh.cc * \brief refresh the statistics and leaf value on the tree on the dataset * \author Tianqi Chen @@ -9,8 +9,7 @@ #include #include -#include "../collective/communicator-inl.h" -#include "../common/io.h" +#include "../collective/allreduce.h" #include "../common/threading_utils.h" #include "../predictor/predict_fn.h" #include "./param.h" @@ -39,7 +38,7 @@ class TreeRefresher : public TreeUpdater { } CHECK_EQ(gpair->Shape(1), 1) << MTNotImplemented(); const std::vector &gpair_h = gpair->Data()->ConstHostVector(); - // thread temporal space + // Thread local variables. std::vector > stemp; std::vector fvec_temp; // setup temp space for each thread @@ -61,9 +60,8 @@ class TreeRefresher : public TreeUpdater { }); } exc.Rethrow(); - // if it is C++11, use lazy evaluation for Allreduce, - // to gain speedup in recovery - auto lazy_get_stats = [&]() { + + auto get_stats = [&]() { const MetaInfo &info = p_fmat->Info(); // start accumulating statistics for (const auto &batch : p_fmat->GetBatches()) { @@ -93,12 +91,17 @@ class TreeRefresher : public TreeUpdater { } }); }; - lazy_get_stats(); - collective::Allreduce(&dmlc::BeginPtr(stemp[0])->sum_grad, - stemp[0].size() * 2); - int offset = 0; + get_stats(); + // Synchronize the aggregated result. + auto &sum_grad = stemp[0]; + // x2 for gradient and hessian. + auto rc = collective::Allreduce( + ctx_, linalg::MakeVec(&sum_grad.data()->sum_grad, sum_grad.size() * 2), + collective::Op::kMax); + collective::SafeColl(rc); + bst_node_t offset = 0; for (auto tree : trees) { - this->Refresh(param, dmlc::BeginPtr(stemp[0]) + offset, 0, tree); + this->Refresh(param, dmlc::BeginPtr(sum_grad) + offset, 0, tree); offset += tree->NumNodes(); } } diff --git a/src/tree/updater_sync.cc b/src/tree/updater_sync.cc index f64f354837f6..6526e519c9d5 100644 --- a/src/tree/updater_sync.cc +++ b/src/tree/updater_sync.cc @@ -1,14 +1,14 @@ /** - * Copyright 2014-2023 by XBGoost Contributors + * Copyright 2014-2024, XBGoost Contributors * \file updater_sync.cc * \brief synchronize the tree in all distributed nodes */ #include -#include #include #include +#include "../collective/broadcast.h" #include "../collective/communicator-inl.h" #include "../common/io.h" #include "xgboost/json.h" @@ -44,7 +44,8 @@ class TreeSyncher : public TreeUpdater { } } fs.Seek(0); - collective::Broadcast(&s_model, 0); + auto rc = collective::Broadcast(ctx_, linalg::MakeVec(s_model.data(), s_model.size()), 0); + SafeColl(rc); for (auto tree : trees) { tree->Load(&fs); } diff --git a/tests/buildkite/conftest.sh b/tests/buildkite/conftest.sh index c6e8ef65a0e4..44043910bc96 100755 --- a/tests/buildkite/conftest.sh +++ b/tests/buildkite/conftest.sh @@ -24,7 +24,7 @@ set -x CUDA_VERSION=11.8.0 NCCL_VERSION=2.16.5-1 -RAPIDS_VERSION=24.02 +RAPIDS_VERSION=24.04 SPARK_VERSION=3.4.0 JDK_VERSION=8 R_VERSION=4.3.2 @@ -39,13 +39,14 @@ fi if [[ -n $BUILDKITE_PULL_REQUEST && $BUILDKITE_PULL_REQUEST != "false" ]] then is_pull_request=1 - export BRANCH_NAME=PR-$BUILDKITE_PULL_REQUEST + BRANCH_NAME=PR-$BUILDKITE_PULL_REQUEST else is_pull_request=0 - export BRANCH_NAME=$BUILDKITE_BRANCH + BRANCH_NAME=$BUILDKITE_BRANCH fi +export BRANCH_NAME=${BRANCH_NAME//\//-} -if [[ $BUILDKITE_BRANCH == "master" || $BUILDKITE_BRANCH == "release_"* ]] +if [[ $BRANCH_NAME == "master" || $BRANCH_NAME == "release_"* ]] then is_release_branch=1 enforce_daily_budget=0 diff --git a/tests/ci_build/Dockerfile.jvm b/tests/ci_build/Dockerfile.jvm index 43fbd8ff596d..a115fd52c2d9 100644 --- a/tests/ci_build/Dockerfile.jvm +++ b/tests/ci_build/Dockerfile.jvm @@ -15,9 +15,9 @@ RUN \ wget -nv -nc https://cmake.org/files/v3.18/cmake-3.18.0-Linux-x86_64.sh --no-check-certificate && \ bash cmake-3.18.0-Linux-x86_64.sh --skip-license --prefix=/usr && \ # Maven - wget -nv -nc https://archive.apache.org/dist/maven/maven-3/3.6.1/binaries/apache-maven-3.6.1-bin.tar.gz && \ - tar xvf apache-maven-3.6.1-bin.tar.gz -C /opt && \ - ln -s /opt/apache-maven-3.6.1/ /opt/maven + wget -nv -nc https://archive.apache.org/dist/maven/maven-3/3.6.3/binaries/apache-maven-3.6.3-bin.tar.gz && \ + tar xvf apache-maven-3.6.3-bin.tar.gz -C /opt && \ + ln -s /opt/apache-maven-3.6.3/ /opt/maven ENV PATH=/opt/mambaforge/bin:/opt/maven/bin:$PATH ENV CC=/opt/rh/devtoolset-9/root/usr/bin/gcc diff --git a/tests/ci_build/Dockerfile.jvm_cross b/tests/ci_build/Dockerfile.jvm_cross index fdfae310aac5..43988872d989 100644 --- a/tests/ci_build/Dockerfile.jvm_cross +++ b/tests/ci_build/Dockerfile.jvm_cross @@ -1,6 +1,6 @@ FROM ubuntu:18.04 ARG JDK_VERSION=8 -ARG SPARK_VERSION=3.0.0 +ARG SPARK_VERSION=3.4.0 # Environment ENV DEBIAN_FRONTEND noninteractive @@ -17,9 +17,9 @@ RUN \ bash conda.sh -b -p /opt/mambaforge && \ /opt/mambaforge/bin/pip install awscli && \ # Maven - wget -nv https://archive.apache.org/dist/maven/maven-3/3.6.1/binaries/apache-maven-3.6.1-bin.tar.gz && \ - tar xvf apache-maven-3.6.1-bin.tar.gz -C /opt && \ - ln -s /opt/apache-maven-3.6.1/ /opt/maven && \ + wget -nv https://archive.apache.org/dist/maven/maven-3/3.6.3/binaries/apache-maven-3.6.3-bin.tar.gz && \ + tar xvf apache-maven-3.6.3-bin.tar.gz -C /opt && \ + ln -s /opt/apache-maven-3.6.3/ /opt/maven && \ # Spark with scala 2.12 mkdir -p /opt/spark-scala-2.12 && \ wget -nv https://archive.apache.org/dist/spark/spark-$SPARK_VERSION/spark-$SPARK_VERSION-bin-hadoop3.tgz && \ diff --git a/tests/ci_build/Dockerfile.jvm_gpu_build b/tests/ci_build/Dockerfile.jvm_gpu_build index 86ce7e72a4b2..cee41894266b 100644 --- a/tests/ci_build/Dockerfile.jvm_gpu_build +++ b/tests/ci_build/Dockerfile.jvm_gpu_build @@ -18,9 +18,9 @@ RUN \ wget -nv -nc https://cmake.org/files/v3.18/cmake-3.18.0-Linux-x86_64.sh --no-check-certificate && \ bash cmake-3.18.0-Linux-x86_64.sh --skip-license --prefix=/usr && \ # Maven - wget -nv -nc https://archive.apache.org/dist/maven/maven-3/3.6.1/binaries/apache-maven-3.6.1-bin.tar.gz && \ - tar xvf apache-maven-3.6.1-bin.tar.gz -C /opt && \ - ln -s /opt/apache-maven-3.6.1/ /opt/maven + wget -nv -nc https://archive.apache.org/dist/maven/maven-3/3.6.3/binaries/apache-maven-3.6.3-bin.tar.gz && \ + tar xvf apache-maven-3.6.3-bin.tar.gz -C /opt && \ + ln -s /opt/apache-maven-3.6.3/ /opt/maven # NCCL2 (License: https://docs.nvidia.com/deeplearning/sdk/nccl-sla/index.html) RUN \ diff --git a/tests/ci_build/build_jvm_packages.sh b/tests/ci_build/build_jvm_packages.sh index 84b41f2b1021..97c056403f0a 100755 --- a/tests/ci_build/build_jvm_packages.sh +++ b/tests/ci_build/build_jvm_packages.sh @@ -18,7 +18,6 @@ fi rm -rf build/ cd jvm-packages -export RABIT_MOCK=ON if [ "x$gpu_arch" != "x" ]; then export GPU_ARCH_FLAG=$gpu_arch diff --git a/tests/ci_build/build_mock_cmake.sh b/tests/ci_build/build_mock_cmake.sh deleted file mode 100755 index 8cbabd036d97..000000000000 --- a/tests/ci_build/build_mock_cmake.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env bash -set -e - -rm -rf build -mkdir build -cd build -cmake -DRABIT_MOCK=ON -DCMAKE_VERBOSE_MAKEFILE=ON .. -make clean -make -j$(nproc) -cd .. diff --git a/tests/ci_build/conda_env/linux_sycl_test.yml b/tests/ci_build/conda_env/linux_sycl_test.yml index bb14c1e77ebb..7335b7f20fd5 100644 --- a/tests/ci_build/conda_env/linux_sycl_test.yml +++ b/tests/ci_build/conda_env/linux_sycl_test.yml @@ -18,3 +18,4 @@ dependencies: - pytest-timeout - pytest-cov - dpcpp_linux-64 +- onedpl-devel diff --git a/tests/ci_build/lint_python.py b/tests/ci_build/lint_python.py index 741ef7558f13..d56191dc4566 100644 --- a/tests/ci_build/lint_python.py +++ b/tests/ci_build/lint_python.py @@ -98,6 +98,7 @@ class LintersPaths: "tests/test_distributed/test_gpu_with_spark/test_data.py", "tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py", # demo + "demo/dask/", "demo/json-model/json_parser.py", "demo/guide-python/external_memory.py", "demo/guide-python/sklearn_examples.py", diff --git a/tests/ci_build/test_r_package.py b/tests/ci_build/test_r_package.py index dd73f850bad0..1fe1644add1f 100644 --- a/tests/ci_build/test_r_package.py +++ b/tests/ci_build/test_r_package.py @@ -53,7 +53,6 @@ def pkgroot(path: str) -> None: # rabit rabit = Path("rabit") os.mkdir(dest / "src" / rabit) - shutil.copytree(rabit / "src", dest / "src" / "rabit" / "src") shutil.copytree(rabit / "include", dest / "src" / "rabit" / "include") # dmlc-core dmlc_core = Path("dmlc-core") @@ -277,6 +276,19 @@ def test_with_cmake(args: argparse.Namespace) -> None: "Release", ] ) + elif args.compiler == "none": + subprocess.check_call( + [ + "cmake", + os.path.pardir, + "-DUSE_OPENMP=ON", + "-DR_LIB=ON", + "-DCMAKE_CONFIGURATION_TYPES=Release", + "-G", + "Unix Makefiles", + ] + ) + subprocess.check_call(["make", "-j", "install"]) else: raise ValueError("Wrong compiler") with DirectoryExcursion(R_PACKAGE): @@ -333,9 +345,9 @@ def main(args: argparse.Namespace) -> None: parser.add_argument( "--compiler", type=str, - choices=["mingw", "msvc"], + choices=["mingw", "msvc", "none"], help="Compiler used for compiling CXX code. Only relevant for windows build", - default="mingw", + default="none", required=False, ) parser.add_argument( diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 20923519ac49..2748e13098b6 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -29,14 +29,14 @@ if(PLUGIN_SYCL) ${xgboost_SOURCE_DIR}/rabit/include) target_compile_definitions(plugin_sycl_test PUBLIC -DXGBOOST_USE_SYCL=1) - target_link_libraries(plugin_sycl_test PUBLIC -fsycl) + target_link_libraries(plugin_sycl_test PRIVATE ${GTEST_LIBRARIES}) set_target_properties(plugin_sycl_test PROPERTIES - COMPILE_FLAGS -fsycl - CXX_STANDARD 17 - CXX_STANDARD_REQUIRED ON - POSITION_INDEPENDENT_CODE ON) + COMPILE_FLAGS -fsycl + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + POSITION_INDEPENDENT_CODE ON) if(USE_OPENMP) find_package(OpenMP REQUIRED) set_target_properties(plugin_sycl_test PROPERTIES @@ -59,7 +59,6 @@ endif() target_sources(testxgboost PRIVATE ${TEST_SOURCES} ${xgboost_SOURCE_DIR}/plugin/example/custom_obj.cc) if(USE_CUDA AND PLUGIN_RMM) - find_package(CUDA) target_include_directories(testxgboost PRIVATE ${CUDA_INCLUDE_DIRS}) endif() @@ -71,7 +70,7 @@ target_include_directories(testxgboost ${xgboost_SOURCE_DIR}/rabit/include) target_link_libraries(testxgboost PRIVATE - ${GTEST_LIBRARIES}) + GTest::gtest GTest::gmock) set_output_directory(testxgboost ${xgboost_BINARY_DIR}) diff --git a/tests/cpp/c_api/test_c_api.cc b/tests/cpp/c_api/test_c_api.cc index c4c1f0c45f42..8729eba82fc3 100644 --- a/tests/cpp/c_api/test_c_api.cc +++ b/tests/cpp/c_api/test_c_api.cc @@ -434,7 +434,7 @@ void MakeLabelForTest(std::shared_ptr Xy, DMatrixHandle cxy) { XGDMatrixSetInfoFromInterface(cxy, "label", s_y_int.c_str()); } -auto MakeSimpleDMatrixForTest(bst_row_t n_samples, bst_feature_t n_features, Json dconfig) { +auto MakeSimpleDMatrixForTest(bst_idx_t n_samples, bst_feature_t n_features, Json dconfig) { HostDeviceVector storage; auto arr_int = RandomDataGenerator{n_samples, n_features, 0.5f}.GenerateArrayInterface(&storage); @@ -451,7 +451,7 @@ auto MakeSimpleDMatrixForTest(bst_row_t n_samples, bst_feature_t n_features, Jso return std::pair{p_fmat, Xy}; } -auto MakeQDMForTest(Context const *ctx, bst_row_t n_samples, bst_feature_t n_features, +auto MakeQDMForTest(Context const *ctx, bst_idx_t n_samples, bst_feature_t n_features, Json dconfig) { bst_bin_t n_bins{16}; dconfig["max_bin"] = Integer{n_bins}; @@ -483,7 +483,7 @@ auto MakeQDMForTest(Context const *ctx, bst_row_t n_samples, bst_feature_t n_fea return std::pair{p_fmat, Xy}; } -auto MakeExtMemForTest(bst_row_t n_samples, bst_feature_t n_features, Json dconfig) { +auto MakeExtMemForTest(bst_idx_t n_samples, bst_feature_t n_features, Json dconfig) { std::size_t n_batches{4}; NumpyArrayIterForTest iter_0{0.0f, n_samples, n_features, n_batches}; std::string s_dconfig; @@ -525,7 +525,7 @@ void CheckResult(Context const *ctx, bst_feature_t n_features, std::shared_ptr -#include - -#include // ifstream - -#include "../helpers.h" // for FileExists - -namespace xgboost::collective { -class SocketTest : public ::testing::Test { - protected: - std::string skip_msg_{"Skipping IPv6 test"}; - - bool SkipTest() { - std::string path{"/sys/module/ipv6/parameters/disable"}; - if (FileExists(path)) { - std::ifstream fin(path); - if (!fin) { - return true; - } - std::string s_value; - fin >> s_value; - auto value = std::stoi(s_value); - if (value != 0) { - return true; - } - } else { - return true; - } - return false; - } - - protected: - void SetUp() override { system::SocketStartup(); } - void TearDown() override { system::SocketFinalize(); } -}; -} // namespace xgboost::collective diff --git a/tests/cpp/collective/test_allgather.cc b/tests/cpp/collective/test_allgather.cc index decad8786538..61e34cb573b9 100644 --- a/tests/cpp/collective/test_allgather.cc +++ b/tests/cpp/collective/test_allgather.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include // for ASSERT_EQ #include // for Span, oper... @@ -34,8 +34,8 @@ class Worker : public WorkerForTest { std::vector data(comm_.World(), 0); data[comm_.Rank()] = comm_.Rank(); - auto rc = RingAllgather(this->comm_, common::Span{data.data(), data.size()}, 1); - ASSERT_TRUE(rc.OK()) << rc.Report(); + auto rc = RingAllgather(this->comm_, common::Span{data.data(), data.size()}); + SafeColl(rc); for (std::int32_t r = 0; r < comm_.World(); ++r) { ASSERT_EQ(data[r], r); @@ -51,8 +51,8 @@ class Worker : public WorkerForTest { auto seg = s_data.subspan(comm_.Rank() * n, n); std::iota(seg.begin(), seg.end(), comm_.Rank()); - auto rc = RingAllgather(comm_, common::Span{data.data(), data.size()}, n); - ASSERT_TRUE(rc.OK()) << rc.Report(); + auto rc = RingAllgather(comm_, common::Span{data.data(), data.size()}); + SafeColl(rc); for (std::int32_t r = 0; r < comm_.World(); ++r) { auto seg = s_data.subspan(r * n, n); @@ -81,7 +81,7 @@ class Worker : public WorkerForTest { std::vector data(comm_.Rank() + 1, comm_.Rank()); std::vector result; auto rc = RingAllgatherV(comm_, common::Span{data.data(), data.size()}, &result); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); ASSERT_EQ(result.size(), (1 + comm_.World()) * comm_.World() / 2); CheckV(result); } @@ -91,7 +91,7 @@ class Worker : public WorkerForTest { std::int32_t n{comm_.Rank()}; std::vector result; auto rc = RingAllgatherV(comm_, common::Span{&n, 1}, &result); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); for (std::int32_t i = 0; i < comm_.World(); ++i) { ASSERT_EQ(result[i], i); } @@ -104,8 +104,8 @@ class Worker : public WorkerForTest { std::vector sizes(comm_.World(), 0); sizes[comm_.Rank()] = s_data.size_bytes(); - auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1); - ASSERT_TRUE(rc.OK()) << rc.Report(); + auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}); + SafeColl(rc); std::shared_ptr pcoll{new Coll{}}; std::vector recv_segments(comm_.World() + 1, 0); @@ -175,4 +175,35 @@ TEST_F(AllgatherTest, VAlgo) { worker.TestVAlgo(); }); } + +TEST(VectorAllgatherV, Basic) { + std::int32_t n_workers{3}; + TestDistributedGlobal(n_workers, []() { + auto n_workers = collective::GetWorldSize(); + ASSERT_EQ(n_workers, 3); + auto rank = collective::GetRank(); + // Construct input that has different length for each worker. + std::vector> inputs; + for (std::int32_t i = 0; i < rank + 1; ++i) { + std::vector in; + for (std::int32_t j = 0; j < rank + 1; ++j) { + in.push_back(static_cast(j)); + } + inputs.emplace_back(std::move(in)); + } + + Context ctx; + auto outputs = VectorAllgatherV(&ctx, inputs); + + ASSERT_EQ(outputs.size(), (1 + n_workers) * n_workers / 2); + auto const& res = outputs; + + for (std::int32_t i = 0; i < n_workers; ++i) { + std::int32_t k = 0; + for (auto v : res[i]) { + ASSERT_EQ(v, k++); + } + } + }); +} } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_allgather.cu b/tests/cpp/collective/test_allgather.cu index 2361081981ab..f145681da46a 100644 --- a/tests/cpp/collective/test_allgather.cu +++ b/tests/cpp/collective/test_allgather.cu @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #if defined(XGBOOST_USE_NCCL) #include @@ -33,8 +33,8 @@ class Worker : public NCCLWorkerForTest { // get size std::vector sizes(comm_.World(), -1); sizes[comm_.Rank()] = s_data.size_bytes(); - auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1); - ASSERT_TRUE(rc.OK()) << rc.Report(); + auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}); + SafeColl(rc); // create result dh::device_vector result(comm_.World(), -1); auto s_result = common::EraseType(dh::ToSpan(result)); @@ -42,7 +42,7 @@ class Worker : public NCCLWorkerForTest { std::vector recv_seg(nccl_comm_->World() + 1, 0); rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()}, common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); for (std::int32_t i = 0; i < comm_.World(); ++i) { ASSERT_EQ(result[i], i); @@ -57,8 +57,8 @@ class Worker : public NCCLWorkerForTest { // get size std::vector sizes(nccl_comm_->World(), 0); sizes[comm_.Rank()] = dh::ToSpan(data).size_bytes(); - auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}, 1); - ASSERT_TRUE(rc.OK()) << rc.Report(); + auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()}); + SafeColl(rc); auto n_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0); // create result dh::device_vector result(n_bytes / sizeof(std::int32_t), -1); @@ -67,7 +67,7 @@ class Worker : public NCCLWorkerForTest { std::vector recv_seg(nccl_comm_->World() + 1, 0); rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()}, common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); // check segment size if (algo != AllgatherVAlgo::kBcast) { auto size = recv_seg[nccl_comm_->Rank() + 1] - recv_seg[nccl_comm_->Rank()]; diff --git a/tests/cpp/collective/test_allreduce.cc b/tests/cpp/collective/test_allreduce.cc index 21b4d9fd0fe2..1ce2f35fd8ef 100644 --- a/tests/cpp/collective/test_allreduce.cc +++ b/tests/cpp/collective/test_allreduce.cc @@ -1,11 +1,12 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include +#include // for iota + #include "../../../src/collective/allreduce.h" #include "../../../src/collective/coll.h" // for Coll -#include "../../../src/collective/tracker.h" #include "../../../src/common/type.h" // for EraseType #include "test_worker.h" // for WorkerForTest, TestDistributed @@ -38,6 +39,22 @@ class AllreduceWorker : public WorkerForTest { } } + void Restricted() { + this->LimitSockBuf(4096); + + std::size_t n = 4096 * 4; + std::vector data(comm_.World() * n, 1); + auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) { + for (std::size_t i = 0; i < rhs.size(); ++i) { + rhs[i] += lhs[i]; + } + }); + ASSERT_TRUE(rc.OK()); + for (auto v : data) { + ASSERT_EQ(v, comm_.World()); + } + } + void Acc() { std::vector data(314, 1.5); auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) { @@ -58,7 +75,7 @@ class AllreduceWorker : public WorkerForTest { auto pcoll = std::shared_ptr{new Coll{}}; auto rc = pcoll->Allreduce(comm_, common::EraseType(common::Span{data.data(), data.size()}), ArrayInterfaceHandler::kU4, Op::kBitwiseOR); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); for (auto v : data) { ASSERT_EQ(v, ~std::uint32_t{0}); } @@ -94,4 +111,45 @@ TEST_F(AllreduceTest, BitOr) { worker.BitOr(); }); } + +TEST_F(AllreduceTest, Restricted) { + std::int32_t n_workers = std::min(3u, std::thread::hardware_concurrency()); + TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, + std::int32_t r) { + AllreduceWorker worker{host, port, timeout, n_workers, r}; + worker.Restricted(); + }); +} + +TEST(AllreduceGlobal, Basic) { + auto n_workers = 3; + TestDistributedGlobal(n_workers, [&]() { + std::vector values(n_workers * 2, 0); + auto rank = GetRank(); + auto s_values = common::Span{values.data(), values.size()}; + auto self = s_values.subspan(rank * 2, 2); + for (auto& v : self) { + v = 1.0f; + } + Context ctx; + auto rc = + Allreduce(&ctx, linalg::MakeVec(s_values.data(), s_values.size()), collective::Op::kSum); + SafeColl(rc); + for (auto v : s_values) { + ASSERT_EQ(v, 1); + } + }); +} + +TEST(AllreduceGlobal, Small) { + // Test when the data is not large enougth to be divided by the number of workers + auto n_workers = 8; + TestDistributedGlobal(n_workers, [&]() { + std::uint64_t value{1}; + Context ctx; + auto rc = Allreduce(&ctx, linalg::MakeVec(&value, 1), collective::Op::kSum); + SafeColl(rc); + ASSERT_EQ(value, n_workers); + }); +} } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_allreduce.cu b/tests/cpp/collective/test_allreduce.cu index 04ec9f773562..8bda1e0de10e 100644 --- a/tests/cpp/collective/test_allreduce.cu +++ b/tests/cpp/collective/test_allreduce.cu @@ -1,11 +1,11 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #if defined(XGBOOST_USE_NCCL) #include #include // for host_vector -#include "../../../src/common/common.h" +#include "../../../src/common/common.h" // for AllVisibleGPUs #include "../../../src/common/device_helpers.cuh" // for ToSpan, device_vector #include "../../../src/common/type.h" // for EraseType #include "test_worker.cuh" // for NCCLWorkerForTest @@ -24,7 +24,7 @@ class Worker : public NCCLWorkerForTest { data[comm_.Rank()] = ~std::uint32_t{0}; auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)), ArrayInterfaceHandler::kU4, Op::kBitwiseOR); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); thrust::host_vector h_data(data.size()); thrust::copy(data.cbegin(), data.cend(), h_data.begin()); for (auto v : h_data) { @@ -36,7 +36,7 @@ class Worker : public NCCLWorkerForTest { dh::device_vector data(314, 1.5); auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)), ArrayInterfaceHandler::kF8, Op::kSum); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); for (std::size_t i = 0; i < data.size(); ++i) { auto v = data[i]; ASSERT_EQ(v, 1.5 * static_cast(comm_.World())) << i; diff --git a/tests/cpp/collective/test_broadcast.cc b/tests/cpp/collective/test_broadcast.cc index 4d0d87e93ae0..1b1d73428be1 100644 --- a/tests/cpp/collective/test_broadcast.cc +++ b/tests/cpp/collective/test_broadcast.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include #include @@ -10,7 +10,6 @@ #include // for vector #include "../../../src/collective/broadcast.h" // for Broadcast -#include "../../../src/collective/tracker.h" // for GetHostAddress #include "test_worker.h" // for WorkerForTest, TestDistributed namespace xgboost::collective { @@ -24,14 +23,14 @@ class Worker : public WorkerForTest { // basic test std::vector data(1, comm_.Rank()); auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); ASSERT_EQ(data[0], r); } for (std::int32_t r = 0; r < comm_.World(); ++r) { std::vector data(1 << 16, comm_.Rank()); auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); ASSERT_EQ(data[0], r); } } @@ -41,11 +40,11 @@ class BroadcastTest : public SocketTest {}; } // namespace TEST_F(BroadcastTest, Basic) { - std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency()); + std::int32_t n_workers = std::min(2u, std::thread::hardware_concurrency()); TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) { Worker worker{host, port, timeout, n_workers, r}; worker.Run(); }); -} // namespace +} } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_coll_c_api.cc b/tests/cpp/collective/test_coll_c_api.cc index d80fbc14073d..c7229ff77eb1 100644 --- a/tests/cpp/collective/test_coll_c_api.cc +++ b/tests/cpp/collective/test_coll_c_api.cc @@ -25,13 +25,13 @@ TEST_F(TrackerAPITest, CAPI) { auto config_str = Json::Dump(config); auto rc = XGTrackerCreate(config_str.c_str(), &handle); ASSERT_EQ(rc, 0); - rc = XGTrackerRun(handle); + rc = XGTrackerRun(handle, nullptr); ASSERT_EQ(rc, 0); std::thread bg_wait{[&] { Json config{Object{}}; auto config_str = Json::Dump(config); - auto rc = XGTrackerWait(handle, config_str.c_str()); + auto rc = XGTrackerWaitFor(handle, config_str.c_str()); ASSERT_EQ(rc, 0); }}; @@ -42,8 +42,8 @@ TEST_F(TrackerAPITest, CAPI) { std::string host; ASSERT_TRUE(GetHostAddress(&host).OK()); - ASSERT_EQ(host, get(args["DMLC_TRACKER_URI"])); - auto port = get(args["DMLC_TRACKER_PORT"]); + ASSERT_EQ(host, get(args["dmlc_tracker_uri"])); + auto port = get(args["dmlc_tracker_port"]); ASSERT_NE(port, 0); std::vector workers; diff --git a/tests/cpp/collective/test_comm.cc b/tests/cpp/collective/test_comm.cc index 8e69b2f8e099..c1eb06465a66 100644 --- a/tests/cpp/collective/test_comm.cc +++ b/tests/cpp/collective/test_comm.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include @@ -14,7 +14,7 @@ class CommTest : public TrackerTest {}; TEST_F(CommTest, Channel) { auto n_workers = 4; - RabitTracker tracker{host, n_workers, 0, timeout}; + RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)}; auto fut = tracker.Run(); std::vector workers; @@ -29,7 +29,7 @@ TEST_F(CommTest, Channel) { return p_chan->SendAll( EraseType(common::Span{&i, static_cast(1)})); } << [&] { return p_chan->Block(); }; - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); } else { auto p_chan = worker.Comm().Chan(i - 1); std::int32_t r{-1}; @@ -37,7 +37,7 @@ TEST_F(CommTest, Channel) { return p_chan->RecvAll( EraseType(common::Span{&r, static_cast(1)})); } << [&] { return p_chan->Block(); }; - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); ASSERT_EQ(r, i - 1); } }); diff --git a/tests/cpp/collective/test_comm_group.cc b/tests/cpp/collective/test_comm_group.cc index 0f6bc23a277d..3b1b5c5df30e 100644 --- a/tests/cpp/collective/test_comm_group.cc +++ b/tests/cpp/collective/test_comm_group.cc @@ -17,17 +17,6 @@ namespace xgboost::collective { namespace { -auto MakeConfig(std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) { - Json config{Object{}}; - config["dmlc_communicator"] = std::string{"rabit"}; - config["DMLC_TRACKER_URI"] = host; - config["DMLC_TRACKER_PORT"] = port; - config["dmlc_timeout_sec"] = static_cast(timeout.count()); - config["DMLC_TASK_ID"] = std::to_string(r); - config["dmlc_retry"] = 2; - return config; -} - class CommGroupTest : public SocketTest {}; } // namespace @@ -36,7 +25,7 @@ TEST_F(CommGroupTest, Basic) { TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) { Context ctx; - auto config = MakeConfig(host, port, timeout, r); + auto config = MakeDistributedTestConfig(host, port, timeout, r); std::unique_ptr ptr{CommGroup::Create(config)}; ASSERT_TRUE(ptr->IsDistributed()); ASSERT_EQ(ptr->World(), n_workers); @@ -52,7 +41,7 @@ TEST_F(CommGroupTest, BasicGPU) { TestDistributed(n_workers, [&](std::string host, std::int32_t port, std::chrono::seconds timeout, std::int32_t r) { auto ctx = MakeCUDACtx(r); - auto config = MakeConfig(host, port, timeout, r); + auto config = MakeDistributedTestConfig(host, port, timeout, r); std::unique_ptr ptr{CommGroup::Create(config)}; auto const& comm = ptr->Ctx(&ctx, DeviceOrd::CUDA(0)); ASSERT_EQ(comm.TaskID(), std::to_string(r)); diff --git a/tests/cpp/collective/test_communicator.cc b/tests/cpp/collective/test_communicator.cc deleted file mode 100644 index a0ca9e747b2b..000000000000 --- a/tests/cpp/collective/test_communicator.cc +++ /dev/null @@ -1,63 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#include -#include - -#include "../../../src/collective/communicator.h" - -namespace xgboost { -namespace collective { - -TEST(CommunicatorFactory, TypeFromEnv) { - EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromEnv()); - - dmlc::SetEnv("XGBOOST_COMMUNICATOR", "foo"); - EXPECT_THROW(Communicator::GetTypeFromEnv(), dmlc::Error); - - dmlc::SetEnv("XGBOOST_COMMUNICATOR", "rabit"); - EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromEnv()); - - dmlc::SetEnv("XGBOOST_COMMUNICATOR", "Federated"); - EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromEnv()); - - dmlc::SetEnv("XGBOOST_COMMUNICATOR", "In-Memory"); - EXPECT_EQ(CommunicatorType::kInMemory, Communicator::GetTypeFromEnv()); -} - -TEST(CommunicatorFactory, TypeFromArgs) { - Json config{JsonObject()}; - EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromConfig(config)); - - config["xgboost_communicator"] = String("rabit"); - EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromConfig(config)); - - config["xgboost_communicator"] = String("federated"); - EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromConfig(config)); - - config["xgboost_communicator"] = String("in-memory"); - EXPECT_EQ(CommunicatorType::kInMemory, Communicator::GetTypeFromConfig(config)); - - config["xgboost_communicator"] = String("foo"); - EXPECT_THROW(Communicator::GetTypeFromConfig(config), dmlc::Error); -} - -TEST(CommunicatorFactory, TypeFromArgsUpperCase) { - Json config{JsonObject()}; - EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromConfig(config)); - - config["XGBOOST_COMMUNICATOR"] = String("rabit"); - EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromConfig(config)); - - config["XGBOOST_COMMUNICATOR"] = String("federated"); - EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromConfig(config)); - - config["XGBOOST_COMMUNICATOR"] = String("in-memory"); - EXPECT_EQ(CommunicatorType::kInMemory, Communicator::GetTypeFromConfig(config)); - - config["XGBOOST_COMMUNICATOR"] = String("foo"); - EXPECT_THROW(Communicator::GetTypeFromConfig(config), dmlc::Error); -} - -} // namespace collective -} // namespace xgboost diff --git a/tests/cpp/collective/test_in_memory_communicator.cc b/tests/cpp/collective/test_in_memory_communicator.cc deleted file mode 100644 index 69c427a4e642..000000000000 --- a/tests/cpp/collective/test_in_memory_communicator.cc +++ /dev/null @@ -1,237 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#include -#include - -#include -#include - -#include "../../../src/collective/in_memory_communicator.h" - -namespace xgboost { -namespace collective { - -class InMemoryCommunicatorTest : public ::testing::Test { - public: - static void Verify(void (*function)(int)) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(function, rank); - } - for (auto &thread : threads) { - thread.join(); - } - } - - static void Allgather(int rank) { - InMemoryCommunicator comm{kWorldSize, rank}; - VerifyAllgather(comm, rank); - } - - static void AllgatherV(int rank) { - InMemoryCommunicator comm{kWorldSize, rank}; - VerifyAllgatherV(comm, rank); - } - - static void AllreduceMax(int rank) { - InMemoryCommunicator comm{kWorldSize, rank}; - VerifyAllreduceMax(comm, rank); - } - - static void AllreduceMin(int rank) { - InMemoryCommunicator comm{kWorldSize, rank}; - VerifyAllreduceMin(comm, rank); - } - - static void AllreduceSum(int rank) { - InMemoryCommunicator comm{kWorldSize, rank}; - VerifyAllreduceSum(comm); - } - - static void AllreduceBitwiseAND(int rank) { - InMemoryCommunicator comm{kWorldSize, rank}; - VerifyAllreduceBitwiseAND(comm, rank); - } - - static void AllreduceBitwiseOR(int rank) { - InMemoryCommunicator comm{kWorldSize, rank}; - VerifyAllreduceBitwiseOR(comm, rank); - } - - static void AllreduceBitwiseXOR(int rank) { - InMemoryCommunicator comm{kWorldSize, rank}; - VerifyAllreduceBitwiseXOR(comm, rank); - } - - static void Broadcast(int rank) { - InMemoryCommunicator comm{kWorldSize, rank}; - VerifyBroadcast(comm, rank); - } - - static void Mixture(int rank) { - InMemoryCommunicator comm{kWorldSize, rank}; - for (auto i = 0; i < 5; i++) { - VerifyAllgather(comm, rank); - VerifyAllreduceMax(comm, rank); - VerifyAllreduceMin(comm, rank); - VerifyAllreduceSum(comm); - VerifyAllreduceBitwiseAND(comm, rank); - VerifyAllreduceBitwiseOR(comm, rank); - VerifyAllreduceBitwiseXOR(comm, rank); - VerifyBroadcast(comm, rank); - } - } - - protected: - static void VerifyAllgather(InMemoryCommunicator &comm, int rank) { - std::string input{static_cast('0' + rank)}; - auto output = comm.AllGather(input); - for (auto i = 0; i < kWorldSize; i++) { - EXPECT_EQ(output[i], static_cast('0' + i)); - } - } - - static void VerifyAllgatherV(InMemoryCommunicator &comm, int rank) { - std::vector inputs{"a", "bb", "ccc"}; - auto output = comm.AllGatherV(inputs[rank]); - EXPECT_EQ(output, "abbccc"); - } - - static void VerifyAllreduceMax(InMemoryCommunicator &comm, int rank) { - int buffer[] = {1 + rank, 2 + rank, 3 + rank, 4 + rank, 5 + rank}; - comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kMax); - int expected[] = {3, 4, 5, 6, 7}; - for (auto i = 0; i < 5; i++) { - EXPECT_EQ(buffer[i], expected[i]); - } - } - - static void VerifyAllreduceMin(InMemoryCommunicator &comm, int rank) { - int buffer[] = {1 + rank, 2 + rank, 3 + rank, 4 + rank, 5 + rank}; - comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kMin); - int expected[] = {1, 2, 3, 4, 5}; - for (auto i = 0; i < 5; i++) { - EXPECT_EQ(buffer[i], expected[i]); - } - } - - static void VerifyAllreduceSum(InMemoryCommunicator &comm) { - int buffer[] = {1, 2, 3, 4, 5}; - comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum); - int expected[] = {3, 6, 9, 12, 15}; - for (auto i = 0; i < 5; i++) { - EXPECT_EQ(buffer[i], expected[i]); - } - } - - static void VerifyAllreduceBitwiseAND(InMemoryCommunicator &comm, int rank) { - std::bitset<2> original(rank); - auto buffer = original.to_ulong(); - comm.AllReduce(&buffer, 1, DataType::kUInt32, Operation::kBitwiseAND); - EXPECT_EQ(buffer, 0UL); - } - - static void VerifyAllreduceBitwiseOR(InMemoryCommunicator &comm, int rank) { - std::bitset<2> original(rank); - auto buffer = original.to_ulong(); - comm.AllReduce(&buffer, 1, DataType::kUInt32, Operation::kBitwiseOR); - std::bitset<2> actual(buffer); - std::bitset<2> expected{0b11}; - EXPECT_EQ(actual, expected); - } - - static void VerifyAllreduceBitwiseXOR(InMemoryCommunicator &comm, int rank) { - std::bitset<3> original(rank * 2); - auto buffer = original.to_ulong(); - comm.AllReduce(&buffer, 1, DataType::kUInt32, Operation::kBitwiseXOR); - std::bitset<3> actual(buffer); - std::bitset<3> expected{0b110}; - EXPECT_EQ(actual, expected); - } - - static void VerifyBroadcast(InMemoryCommunicator &comm, int rank) { - if (rank == 0) { - std::string buffer{"hello"}; - comm.Broadcast(&buffer[0], buffer.size(), 0); - EXPECT_EQ(buffer, "hello"); - } else { - std::string buffer{" "}; - comm.Broadcast(&buffer[0], buffer.size(), 0); - EXPECT_EQ(buffer, "hello"); - } - } - - static int const kWorldSize{3}; -}; - -TEST(InMemoryCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) { - auto construct = []() { InMemoryCommunicator comm{0, 0}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(InMemoryCommunicatorSimpleTest, ThrowOnRankTooSmall) { - auto construct = []() { InMemoryCommunicator comm{1, -1}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(InMemoryCommunicatorSimpleTest, ThrowOnRankTooBig) { - auto construct = []() { InMemoryCommunicator comm{1, 1}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(InMemoryCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) { - auto construct = []() { - Json config{JsonObject()}; - config["in_memory_world_size"] = std::string("1"); - config["in_memory_rank"] = Integer(0); - auto *comm = InMemoryCommunicator::Create(config); - delete comm; - }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(InMemoryCommunicatorSimpleTest, ThrowOnRankNotInteger) { - auto construct = []() { - Json config{JsonObject()}; - config["in_memory_world_size"] = 1; - config["in_memory_rank"] = std::string("0"); - auto *comm = InMemoryCommunicator::Create(config); - delete comm; - }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(InMemoryCommunicatorSimpleTest, GetWorldSizeAndRank) { - InMemoryCommunicator comm{1, 0}; - EXPECT_EQ(comm.GetWorldSize(), 1); - EXPECT_EQ(comm.GetRank(), 0); -} - -TEST(InMemoryCommunicatorSimpleTest, IsDistributed) { - InMemoryCommunicator comm{1, 0}; - EXPECT_TRUE(comm.IsDistributed()); -} - -TEST_F(InMemoryCommunicatorTest, Allgather) { Verify(&Allgather); } - -TEST_F(InMemoryCommunicatorTest, AllgatherV) { Verify(&AllgatherV); } - -TEST_F(InMemoryCommunicatorTest, AllreduceMax) { Verify(&AllreduceMax); } - -TEST_F(InMemoryCommunicatorTest, AllreduceMin) { Verify(&AllreduceMin); } - -TEST_F(InMemoryCommunicatorTest, AllreduceSum) { Verify(&AllreduceSum); } - -TEST_F(InMemoryCommunicatorTest, AllreduceBitwiseAND) { Verify(&AllreduceBitwiseAND); } - -TEST_F(InMemoryCommunicatorTest, AllreduceBitwiseOR) { Verify(&AllreduceBitwiseOR); } - -TEST_F(InMemoryCommunicatorTest, AllreduceBitwiseXOR) { Verify(&AllreduceBitwiseXOR); } - -TEST_F(InMemoryCommunicatorTest, Broadcast) { Verify(&Broadcast); } - -TEST_F(InMemoryCommunicatorTest, Mixture) { Verify(&Mixture); } - -} // namespace collective -} // namespace xgboost diff --git a/tests/cpp/collective/test_loop.cc b/tests/cpp/collective/test_loop.cc index e5ef987f3fd5..622b350aaae8 100644 --- a/tests/cpp/collective/test_loop.cc +++ b/tests/cpp/collective/test_loop.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include // for ASSERT_TRUE, ASSERT_EQ #include // for TCPSocket, Connect, SocketFinalize, SocketStartup @@ -28,18 +28,23 @@ class LoopTest : public ::testing::Test { auto domain = SockDomain::kV4; pair_.first = TCPSocket::Create(domain); - auto port = pair_.first.BindHost(); - pair_.first.Listen(); + std::int32_t port{0}; + auto rc = Success() << [&] { + return pair_.first.BindHost(&port); + } << [&] { + return pair_.first.Listen(); + }; + SafeColl(rc); auto const& addr = SockAddrV4::Loopback().Addr(); - auto rc = Connect(StringView{addr}, port, 1, timeout, &pair_.second); - ASSERT_TRUE(rc.OK()); + rc = Connect(StringView{addr}, port, 1, timeout, &pair_.second); + SafeColl(rc); rc = pair_.second.NonBlocking(true); - ASSERT_TRUE(rc.OK()); + SafeColl(rc); pair_.first = pair_.first.Accept(); rc = pair_.first.NonBlocking(true); - ASSERT_TRUE(rc.OK()); + SafeColl(rc); loop_ = std::shared_ptr{new Loop{timeout}}; } @@ -54,7 +59,7 @@ class LoopTest : public ::testing::Test { TEST_F(LoopTest, Timeout) { std::vector data(1); Loop::Op op{Loop::Op::kRead, 0, data.data(), data.size(), &pair_.second, 0}; - loop_->Submit(op); + loop_->Submit(std::move(op)); auto rc = loop_->Block(); ASSERT_FALSE(rc.OK()); ASSERT_EQ(rc.Code(), std::make_error_code(std::errc::timed_out)) << rc.Report(); @@ -70,12 +75,30 @@ TEST_F(LoopTest, Op) { Loop::Op wop{Loop::Op::kWrite, 0, wbuf.data(), wbuf.size(), &send, 0}; Loop::Op rop{Loop::Op::kRead, 0, rbuf.data(), rbuf.size(), &recv, 0}; - loop_->Submit(wop); - loop_->Submit(rop); + loop_->Submit(std::move(wop)); + loop_->Submit(std::move(rop)); auto rc = loop_->Block(); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); ASSERT_EQ(rbuf[0], wbuf[0]); } + +TEST_F(LoopTest, Block) { + // We need to ensure that a blocking call doesn't go unanswered. + auto op = Loop::Op::Sleep(2); + + common::Timer t; + t.Start(); + loop_->Submit(std::move(op)); + t.Stop(); + // submit is non-blocking + ASSERT_LT(t.ElapsedSeconds(), 1); + + t.Start(); + auto rc = loop_->Block(); + t.Stop(); + SafeColl(rc); + ASSERT_GE(t.ElapsedSeconds(), 1); +} } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_nccl_device_communicator.cu b/tests/cpp/collective/test_nccl_device_communicator.cu deleted file mode 100644 index 47e86220d7e8..000000000000 --- a/tests/cpp/collective/test_nccl_device_communicator.cu +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright 2022-2023, XGBoost contributors - */ -#ifdef XGBOOST_USE_NCCL - -#include - -#include -#include // for string - -#include "../../../src/collective/comm.cuh" -#include "../../../src/collective/communicator-inl.cuh" -#include "../../../src/collective/nccl_device_communicator.cuh" -#include "../helpers.h" - -namespace xgboost { -namespace collective { - -TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) { - auto construct = []() { NcclDeviceCommunicator comm{-1, false, DefaultNcclName()}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(NcclDeviceCommunicatorSimpleTest, SystemError) { - auto stub = std::make_shared(DefaultNcclName()); - auto rc = stub->GetNcclResult(ncclSystemError); - auto msg = rc.Report(); - ASSERT_TRUE(msg.find("environment variables") != std::string::npos); -} - -namespace { -void VerifyAllReduceBitwiseAND() { - auto const rank = collective::GetRank(); - std::bitset<64> original{}; - original[rank] = true; - HostDeviceVector buffer({original.to_ullong()}, DeviceOrd::CUDA(rank)); - collective::AllReduce(rank, buffer.DevicePointer(), 1); - collective::Synchronize(rank); - EXPECT_EQ(buffer.HostVector()[0], 0ULL); -} -} // anonymous namespace - -TEST(NcclDeviceCommunicator, MGPUAllReduceBitwiseAND) { - auto const n_gpus = common::AllVisibleGPUs(); - if (n_gpus <= 1) { - GTEST_SKIP() << "Skipping MGPUAllReduceBitwiseAND test with # GPUs = " << n_gpus; - } - auto constexpr kUseNccl = true; - RunWithInMemoryCommunicator(n_gpus, VerifyAllReduceBitwiseAND); -} - -namespace { -void VerifyAllReduceBitwiseOR() { - auto const world_size = collective::GetWorldSize(); - auto const rank = collective::GetRank(); - std::bitset<64> original{}; - original[rank] = true; - HostDeviceVector buffer({original.to_ullong()}, DeviceOrd::CUDA(rank)); - collective::AllReduce(rank, buffer.DevicePointer(), 1); - collective::Synchronize(rank); - EXPECT_EQ(buffer.HostVector()[0], (1ULL << world_size) - 1); -} -} // anonymous namespace - -TEST(NcclDeviceCommunicator, MGPUAllReduceBitwiseOR) { - auto const n_gpus = common::AllVisibleGPUs(); - if (n_gpus <= 1) { - GTEST_SKIP() << "Skipping MGPUAllReduceBitwiseOR test with # GPUs = " << n_gpus; - } - auto constexpr kUseNccl = true; - RunWithInMemoryCommunicator(n_gpus, VerifyAllReduceBitwiseOR); -} - -namespace { -void VerifyAllReduceBitwiseXOR() { - auto const world_size = collective::GetWorldSize(); - auto const rank = collective::GetRank(); - std::bitset<64> original{~0ULL}; - original[rank] = false; - HostDeviceVector buffer({original.to_ullong()}, DeviceOrd::CUDA(rank)); - collective::AllReduce(rank, buffer.DevicePointer(), 1); - collective::Synchronize(rank); - EXPECT_EQ(buffer.HostVector()[0], (1ULL << world_size) - 1); -} -} // anonymous namespace - -TEST(NcclDeviceCommunicator, MGPUAllReduceBitwiseXOR) { - auto const n_gpus = common::AllVisibleGPUs(); - if (n_gpus <= 1) { - GTEST_SKIP() << "Skipping MGPUAllReduceBitwiseXOR test with # GPUs = " << n_gpus; - } - auto constexpr kUseNccl = true; - RunWithInMemoryCommunicator(n_gpus, VerifyAllReduceBitwiseXOR); -} - -} // namespace collective -} // namespace xgboost - -#endif // XGBOOST_USE_NCCL diff --git a/tests/cpp/collective/test_rabit_communicator.cc b/tests/cpp/collective/test_rabit_communicator.cc deleted file mode 100644 index 9711e1aede71..000000000000 --- a/tests/cpp/collective/test_rabit_communicator.cc +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2022-2024, XGBoost contributors - */ -#include - -#include "../../../src/collective/rabit_communicator.h" -#include "../helpers.h" - -namespace xgboost::collective { -TEST(RabitCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) { - auto construct = []() { RabitCommunicator comm{0, 0}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(RabitCommunicatorSimpleTest, ThrowOnRankTooSmall) { - auto construct = []() { RabitCommunicator comm{1, -1}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(RabitCommunicatorSimpleTest, ThrowOnRankTooBig) { - auto construct = []() { RabitCommunicator comm{1, 1}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(RabitCommunicatorSimpleTest, GetWorldSizeAndRank) { - RabitCommunicator comm{6, 3}; - EXPECT_EQ(comm.GetWorldSize(), 6); - EXPECT_EQ(comm.GetRank(), 3); -} - -TEST(RabitCommunicatorSimpleTest, IsNotDistributed) { - RabitCommunicator comm{2, 1}; - // Rabit is only distributed with a tracker. - EXPECT_FALSE(comm.IsDistributed()); -} - -namespace { -void VerifyVectorAllgatherV() { - auto n_workers = collective::GetWorldSize(); - ASSERT_EQ(n_workers, 3); - auto rank = collective::GetRank(); - // Construct input that has different length for each worker. - std::vector> inputs; - for (std::int32_t i = 0; i < rank + 1; ++i) { - std::vector in; - for (std::int32_t j = 0; j < rank + 1; ++j) { - in.push_back(static_cast(j)); - } - inputs.emplace_back(std::move(in)); - } - - auto outputs = VectorAllgatherV(inputs); - - ASSERT_EQ(outputs.size(), (1 + n_workers) * n_workers / 2); - auto const& res = outputs; - - for (std::int32_t i = 0; i < n_workers; ++i) { - std::int32_t k = 0; - for (auto v : res[i]) { - ASSERT_EQ(v, k++); - } - } -} -} // namespace - -TEST(VectorAllgatherV, Basic) { - std::int32_t n_workers{3}; - RunWithInMemoryCommunicator(n_workers, VerifyVectorAllgatherV); -} -} // namespace xgboost::collective diff --git a/tests/cpp/collective/test_result.cc b/tests/cpp/collective/test_result.cc new file mode 100644 index 000000000000..1c7194f92f5c --- /dev/null +++ b/tests/cpp/collective/test_result.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2024, XGBoost Contributors + */ +#include +#include + +namespace xgboost::collective { +TEST(Result, Concat) { + auto rc0 = Fail("foo"); + auto rc1 = Fail("bar"); + auto rc = std::move(rc0) + std::move(rc1); + ASSERT_NE(rc.Report().find("foo"), std::string::npos); + ASSERT_NE(rc.Report().find("bar"), std::string::npos); + + auto rc2 = Fail("Another", std::move(rc)); + auto assert_that = [](Result const& rc) { + ASSERT_NE(rc.Report().find("Another"), std::string::npos); + ASSERT_NE(rc.Report().find("foo"), std::string::npos); + ASSERT_NE(rc.Report().find("bar"), std::string::npos); + }; + assert_that(rc2); + + auto empty = Success(); + auto rc3 = std::move(empty) + std::move(rc2); + assert_that(rc3); + + empty = Success(); + auto rc4 = std::move(rc3) + std::move(empty); + assert_that(rc4); +} +} // namespace xgboost::collective diff --git a/tests/cpp/collective/test_socket.cc b/tests/cpp/collective/test_socket.cc index ced795fef9a9..ea57da9b4a43 100644 --- a/tests/cpp/collective/test_socket.cc +++ b/tests/cpp/collective/test_socket.cc @@ -1,5 +1,5 @@ /** - * Copyright 2022-2023, XGBoost Contributors + * Copyright 2022-2024, XGBoost Contributors */ #include #include @@ -21,14 +21,19 @@ TEST_F(SocketTest, Basic) { auto run_test = [msg](SockDomain domain) { auto server = TCPSocket::Create(domain); ASSERT_EQ(server.Domain(), domain); - auto port = server.BindHost(); - server.Listen(); + std::int32_t port{0}; + auto rc = Success() << [&] { + return server.BindHost(&port); + } << [&] { + return server.Listen(); + }; + SafeColl(rc); TCPSocket client; if (domain == SockDomain::kV4) { auto const& addr = SockAddrV4::Loopback().Addr(); auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); } else { auto const& addr = SockAddrV6::Loopback().Addr(); auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client); @@ -45,7 +50,8 @@ TEST_F(SocketTest, Basic) { accepted.Send(msg); std::string str; - client.Recv(&str); + rc = client.Recv(&str); + SafeColl(rc); ASSERT_EQ(StringView{str}, msg); }; diff --git a/tests/cpp/collective/test_tracker.cc b/tests/cpp/collective/test_tracker.cc index 0dce33c0cc62..e31e2662854e 100644 --- a/tests/cpp/collective/test_tracker.cc +++ b/tests/cpp/collective/test_tracker.cc @@ -1,6 +1,7 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ +#include #include #include // for seconds @@ -10,6 +11,7 @@ #include // for vector #include "../../../src/collective/comm.h" +#include "../helpers.h" // for GMockThrow #include "test_worker.h" namespace xgboost::collective { @@ -20,13 +22,14 @@ class PrintWorker : public WorkerForTest { void Print() { auto rc = comm_.LogTracker("ack:" + std::to_string(this->comm_.Rank())); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); } }; } // namespace TEST_F(TrackerTest, Bootstrap) { - RabitTracker tracker{host, n_workers, 0, timeout}; + RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)}; + ASSERT_TRUE(HasTimeout(tracker.Timeout())); ASSERT_FALSE(tracker.Ready()); auto fut = tracker.Run(); @@ -34,7 +37,7 @@ TEST_F(TrackerTest, Bootstrap) { auto args = tracker.WorkerArgs(); ASSERT_TRUE(tracker.Ready()); - ASSERT_EQ(get(args["DMLC_TRACKER_URI"]), host); + ASSERT_EQ(get(args["dmlc_tracker_uri"]), host); std::int32_t port = tracker.Port(); @@ -44,12 +47,14 @@ TEST_F(TrackerTest, Bootstrap) { for (auto &w : workers) { w.join(); } + SafeColl(fut.get()); - ASSERT_TRUE(fut.get().OK()); + ASSERT_FALSE(HasTimeout(std::chrono::seconds{-1})); + ASSERT_FALSE(HasTimeout(std::chrono::seconds{0})); } TEST_F(TrackerTest, Print) { - RabitTracker tracker{host, n_workers, 0, timeout}; + RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)}; auto fut = tracker.Run(); std::vector workers; @@ -73,4 +78,47 @@ TEST_F(TrackerTest, Print) { } TEST_F(TrackerTest, GetHostAddress) { ASSERT_TRUE(host.find("127.") == std::string::npos); } + +/** + * Test connecting the tracker after it has finished. This should not hang the workers. + */ +TEST_F(TrackerTest, AfterShutdown) { + RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)}; + auto fut = tracker.Run(); + + std::vector workers; + auto rc = tracker.WaitUntilReady(); + ASSERT_TRUE(rc.OK()); + + std::int32_t port = tracker.Port(); + + // Launch no-op workers to cause the tracker to shutdown. + for (std::int32_t i = 0; i < n_workers; ++i) { + workers.emplace_back([=] { WorkerForTest worker{host, port, timeout, n_workers, i}; }); + } + + for (auto &w : workers) { + w.join(); + } + + ASSERT_TRUE(fut.get().OK()); + + // Launch workers again, they should fail. + workers.clear(); + for (std::int32_t i = 0; i < n_workers; ++i) { + auto assert_that = [=] { + WorkerForTest worker{host, port, timeout, n_workers, i}; + }; + // On a Linux platform, the connection will be refused, on Apple platform, this gets + // an operation now in progress poll failure, on Windows, it's a timeout error. +#if defined(__linux__) + workers.emplace_back([=] { ASSERT_THAT(assert_that, GMockThrow("Connection refused")); }); +#else + workers.emplace_back([=] { ASSERT_THAT(assert_that, GMockThrow("Failed to connect to")); }); +#endif + } + for (auto &w : workers) { + w.join(); + } +} } // namespace xgboost::collective diff --git a/tests/cpp/collective/test_worker.h b/tests/cpp/collective/test_worker.h index acee0f2970ca..f1889200b4d6 100644 --- a/tests/cpp/collective/test_worker.h +++ b/tests/cpp/collective/test_worker.h @@ -1,19 +1,26 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #pragma once #include #include // for seconds #include // for int32_t +#include // for ifstream #include // for string #include // for thread #include // for move #include // for vector #include "../../../src/collective/comm.h" -#include "../../../src/collective/tracker.h" // for GetHostAddress -#include "../helpers.h" // for FileExists +#include "../../../src/collective/communicator-inl.h" // for Init, Finalize +#include "../../../src/collective/tracker.h" // for GetHostAddress +#include "../../../src/common/common.h" // for AllVisibleGPUs +#include "../helpers.h" // for FileExists + +#if defined(XGBOOST_USE_FEDERATED) +#include "../plugin/federated/test_worker.h" +#endif // defined(XGBOOST_USE_FEDERATED) namespace xgboost::collective { class WorkerForTest { @@ -36,7 +43,7 @@ class WorkerForTest { comm_{tracker_host_, tracker_port_, timeout, retry_, task_id_, DefaultNcclName()} { CHECK_EQ(world_size_, comm_.World()); } - virtual ~WorkerForTest() = default; + virtual ~WorkerForTest() noexcept(false) { SafeColl(comm_.Shutdown()); } auto& Comm() { return comm_; } void LimitSockBuf(std::int32_t n_bytes) { @@ -44,6 +51,7 @@ class WorkerForTest { if (i != comm_.Rank()) { ASSERT_TRUE(comm_.Chan(i)->Socket()->NonBlocking()); ASSERT_TRUE(comm_.Chan(i)->Socket()->SetBufSize(n_bytes).OK()); + ASSERT_TRUE(comm_.Chan(i)->Socket()->SetNoDelay().OK()); } } } @@ -86,19 +94,30 @@ class TrackerTest : public SocketTest { void SetUp() override { SocketTest::SetUp(); auto rc = GetHostAddress(&host); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); } }; +inline Json MakeTrackerConfig(std::string host, std::int32_t n_workers, + std::chrono::seconds timeout) { + Json config{Object{}}; + config["host"] = host; + config["port"] = Integer{0}; + config["n_workers"] = Integer{n_workers}; + config["sortby"] = Integer{static_cast(Tracker::SortBy::kHost)}; + config["timeout"] = timeout.count(); + return config; +} + template void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) { std::chrono::seconds timeout{2}; std::string host; auto rc = GetHostAddress(&host); - ASSERT_TRUE(rc.OK()) << rc.Report(); + SafeColl(rc); LOG(INFO) << "Using " << n_workers << " workers for test."; - RabitTracker tracker{StringView{host}, n_workers, 0, timeout}; + RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)}; auto fut = tracker.Run(); std::vector workers; @@ -114,4 +133,86 @@ void TestDistributed(std::int32_t n_workers, WorkerFn worker_fn) { ASSERT_TRUE(fut.get().OK()); } + +inline auto MakeDistributedTestConfig(std::string host, std::int32_t port, + std::chrono::seconds timeout, std::int32_t r) { + Json config{Object{}}; + config["dmlc_communicator"] = std::string{"rabit"}; + config["dmlc_tracker_uri"] = host; + config["dmlc_tracker_port"] = port; + config["dmlc_timeout"] = static_cast(timeout.count()); + config["dmlc_task_id"] = std::to_string(r); + config["dmlc_retry"] = 2; + return config; +} + +template +void TestDistributedGlobal(std::int32_t n_workers, WorkerFn worker_fn, bool need_finalize = true, + std::chrono::seconds test_timeout = std::chrono::seconds{30}) { + system::SocketStartup(); + std::chrono::seconds timeout{1}; + + std::string host; + auto rc = GetHostAddress(&host); + SafeColl(rc); + + RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)}; + auto fut = tracker.Run(); + + std::vector workers; + std::int32_t port = tracker.Port(); + + for (std::int32_t i = 0; i < n_workers; ++i) { + workers.emplace_back([=] { + auto fut = std::async(std::launch::async, [=] { + auto config = MakeDistributedTestConfig(host, port, timeout, i); + Init(config); + worker_fn(); + if (need_finalize) { + Finalize(); + } + }); + auto status = fut.wait_for(test_timeout); + CHECK(status == std::future_status::ready) << "Test timeout"; + fut.get(); + }); + } + + for (auto& t : workers) { + t.join(); + } + + ASSERT_TRUE(fut.get().OK()); + system::SocketFinalize(); +} + +class BaseMGPUTest : public ::testing::Test { + public: + /** + * @param emulate_if_single Emulate multi-GPU for federated test if there's only one GPU + * available. + */ + template + auto DoTest(Fn&& fn, bool is_federated, bool emulate_if_single = false) const { + auto n_gpus = common::AllVisibleGPUs(); + if (is_federated) { +#if defined(XGBOOST_USE_FEDERATED) + if (n_gpus == 1 && emulate_if_single) { + // Emulate multiple GPUs. + // We don't use nccl and can have multiple communicators running on the same device. + n_gpus = 3; + } + TestFederatedGlobal(n_gpus, fn); +#else + GTEST_SKIP_("Not compiled with federated learning."); +#endif // defined(XGBOOST_USE_FEDERATED) + } else { +#if defined(XGBOOST_USE_NCCL) + TestDistributedGlobal(n_gpus, fn); +#else + GTEST_SKIP_("Not compiled with NCCL."); +#endif // defined(XGBOOST_USE_NCCL) + } + } +}; } // namespace xgboost::collective diff --git a/tests/cpp/common/test_device_helpers.cu b/tests/cpp/common/test_device_helpers.cu index 1d10a48ad1a5..4178e55d8fd8 100644 --- a/tests/cpp/common/test_device_helpers.cu +++ b/tests/cpp/common/test_device_helpers.cu @@ -1,14 +1,16 @@ -/*! - * Copyright 2017-2021 XGBoost contributors +/** + * Copyright 2017-2024, XGBoost contributors */ +#include +#include // for is_sorted +#include + #include #include -#include #include -#include + #include "../../../src/common/device_helpers.cuh" #include "../../../src/common/quantile.h" -#include "../helpers.h" #include "gtest/gtest.h" TEST(SumReduce, Test) { diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index 5391bc2cfa48..24e67c9aa4e6 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -1,10 +1,9 @@ /** - * Copyright 2019-2023 by XGBoost Contributors + * Copyright 2019-2024, XGBoost Contributors */ #include #include #include -#include #include "../../../src/common/hist_util.h" #include "../../../src/data/gradient_index.h" @@ -135,7 +134,7 @@ TEST(CutsBuilder, SearchGroupInd) { group[2] = 7; group[3] = 5; - p_mat->SetInfo("group", group.data(), DataType::kUInt32, kNumGroups); + p_mat->SetInfo("group", Make1dInterfaceTest(group.data(), group.size())); HistogramCuts hmat; @@ -348,7 +347,8 @@ void TestSketchFromWeights(bool with_group) { for (size_t i = 0; i < kGroups; ++i) { groups[i] = kRows / kGroups; } - info.SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups); + auto sg = linalg::Make1dInterface(groups.data(), kGroups); + info.SetInfo(ctx, "group", sg.c_str()); } info.num_row_ = kRows; @@ -356,10 +356,10 @@ void TestSketchFromWeights(bool with_group) { // Assign weights. if (with_group) { - m->SetInfo("group", groups.data(), DataType::kUInt32, kGroups); + m->SetInfo("group", Make1dInterfaceTest(groups.data(), kGroups)); } - m->SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size()); + m->SetInfo("weight", Make1dInterfaceTest(h_weights.data(), h_weights.size())); m->Info().num_col_ = kCols; m->Info().num_row_ = kRows; ASSERT_EQ(cuts.Ptrs().size(), kCols + 1); diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index c0d5c5ddc109..e37f02ddb2c9 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -1,5 +1,5 @@ /** - * Copyright 2019-2023 by XGBoost Contributors + * Copyright 2019-2024, XGBoost Contributors */ #include #include @@ -179,7 +179,7 @@ void TestMixedSketch() { TEST(HistUtil, DeviceSketchMixedFeatures) { TestMixedSketch(); } TEST(HistUtil, RemoveDuplicatedCategories) { - bst_row_t n_samples = 512; + bst_idx_t n_samples = 512; bst_feature_t n_features = 3; bst_cat_t n_categories = 5; @@ -208,13 +208,13 @@ TEST(HistUtil, RemoveDuplicatedCategories) { FeatureType::kNumerical, FeatureType::kCategorical, FeatureType::kNumerical}; ASSERT_EQ(info.feature_types.Size(), n_features); - HostDeviceVector cuts_ptr{0, n_samples, n_samples * 2, n_samples * 3}; + HostDeviceVector cuts_ptr{0, n_samples, n_samples * 2, n_samples * 3}; cuts_ptr.SetDevice(DeviceOrd::CUDA(0)); dh::device_vector weight(n_samples * n_features, 0); dh::Iota(dh::ToSpan(weight), ctx.CUDACtx()->Stream()); - dh::caching_device_vector columns_ptr(4); + dh::caching_device_vector columns_ptr(4); for (std::size_t i = 0; i < columns_ptr.size(); ++i) { columns_ptr[i] = i * n_samples; } @@ -639,7 +639,7 @@ void TestGetColumnSize(std::size_t n_samples) { } // namespace TEST(HistUtil, GetColumnSize) { - bst_row_t n_samples = 4096; + bst_idx_t n_samples = 4096; TestGetColumnSize(n_samples); } @@ -682,7 +682,7 @@ TEST(HistUtil, DeviceSketchFromGroupWeights) { for (size_t i = 0; i < kGroups; ++i) { groups[i] = kRows / kGroups; } - m->SetInfo("group", groups.data(), DataType::kUInt32, kGroups); + m->SetInfo("group", Make1dInterfaceTest(groups.data(), kGroups)); HistogramCuts weighted_cuts = DeviceSketch(&ctx, m.get(), kBins, 0); // sketch with no weight @@ -727,7 +727,7 @@ void TestAdapterSketchFromWeights(bool with_group) { for (size_t i = 0; i < kGroups; ++i) { groups[i] = kRows / kGroups; } - info.SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups); + info.SetInfo(ctx, "group", Make1dInterfaceTest(groups.data(), kGroups)); } info.weights_.SetDevice(DeviceOrd::CUDA(0)); @@ -746,10 +746,10 @@ void TestAdapterSketchFromWeights(bool with_group) { auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols); if (with_group) { - dmat->Info().SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups); + dmat->Info().SetInfo(ctx, "group", Make1dInterfaceTest(groups.data(), kGroups)); } - dmat->Info().SetInfo(ctx, "weight", h_weights.data(), DataType::kFloat32, h_weights.size()); + dmat->Info().SetInfo(ctx, "weight", Make1dInterfaceTest(h_weights.data(), h_weights.size())); dmat->Info().num_col_ = kCols; dmat->Info().num_row_ = kRows; ASSERT_EQ(cuts.Ptrs().size(), kCols + 1); @@ -795,11 +795,11 @@ TEST(HistUtil, AdapterSketchFromWeights) { namespace { class DeviceSketchWithHessianTest - : public ::testing::TestWithParam> { + : public ::testing::TestWithParam> { bst_feature_t n_features_ = 5; bst_group_t n_groups_{3}; - auto GenerateHessian(Context const* ctx, bst_row_t n_samples) const { + auto GenerateHessian(Context const* ctx, bst_idx_t n_samples) const { HostDeviceVector hessian; auto& h_hess = hessian.HostVector(); h_hess = GenerateRandomWeights(n_samples); @@ -844,7 +844,7 @@ class DeviceSketchWithHessianTest protected: Context ctx_ = MakeCUDACtx(0); - void TestLTR(Context const* ctx, bst_row_t n_samples, bst_bin_t n_bins, + void TestLTR(Context const* ctx, bst_idx_t n_samples, bst_bin_t n_bins, std::size_t n_elements) const { auto x = GenerateRandom(n_samples, n_features_); @@ -897,7 +897,7 @@ class DeviceSketchWithHessianTest } } - void TestRegression(Context const* ctx, bst_row_t n_samples, bst_bin_t n_bins, + void TestRegression(Context const* ctx, bst_idx_t n_samples, bst_bin_t n_bins, std::size_t n_elements) const { auto x = GenerateRandom(n_samples, n_features_); auto p_fmat = GetDMatrixFromData(x, n_samples, n_features_); @@ -910,9 +910,9 @@ class DeviceSketchWithHessianTest }; auto MakeParamsForTest() { - std::vector sizes = {1, 2, 256, 512, 1000, 1500}; + std::vector sizes = {1, 2, 256, 512, 1000, 1500}; std::vector bin_sizes = {2, 16, 256, 512}; - std::vector> configs; + std::vector> configs; for (auto n_samples : sizes) { for (auto n_bins : bin_sizes) { configs.emplace_back(true, n_samples, n_bins); diff --git a/tests/cpp/common/test_io.cc b/tests/cpp/common/test_io.cc index 4c4d4efe035b..e7f72dc27f71 100644 --- a/tests/cpp/common/test_io.cc +++ b/tests/cpp/common/test_io.cc @@ -1,10 +1,11 @@ /** - * Copyright 2019-2023, XGBoost Contributors + * Copyright 2019-2024, XGBoost Contributors */ #include #include // for size_t #include // for ofstream +#include // for iota #include "../../../src/common/io.h" #include "../filesystem.h" // dmlc::TemporaryDirectory @@ -14,8 +15,8 @@ namespace xgboost::common { TEST(MemoryFixSizeBuffer, Seek) { size_t constexpr kSize { 64 }; std::vector memory( kSize ); - rabit::utils::MemoryFixSizeBuffer buf(memory.data(), memory.size()); - buf.Seek(rabit::utils::MemoryFixSizeBuffer::kSeekEnd); + MemoryFixSizeBuffer buf(memory.data(), memory.size()); + buf.Seek(MemoryFixSizeBuffer::kSeekEnd); size_t end = buf.Tell(); ASSERT_EQ(end, kSize); } diff --git a/tests/cpp/common/test_json.cc b/tests/cpp/common/test_json.cc index 3ee041a339ed..e144bdc45b9f 100644 --- a/tests/cpp/common/test_json.cc +++ b/tests/cpp/common/test_json.cc @@ -4,10 +4,10 @@ #include #include -#include // for back_inserter +#include // for numeric_limits #include +#include // for iota -#include "../../../src/common/charconv.h" #include "../../../src/common/io.h" #include "../../../src/common/json_utils.h" #include "../../../src/common/threading_utils.h" // for ParallelFor diff --git a/tests/cpp/common/test_linalg.cu b/tests/cpp/common/test_linalg.cu index 5f8bab4a3cc4..bf217842b660 100644 --- a/tests/cpp/common/test_linalg.cu +++ b/tests/cpp/common/test_linalg.cu @@ -1,8 +1,11 @@ /** - * Copyright 2021-2023 by XGBoost Contributors + * Copyright 2021-2024, XGBoost Contributors */ #include +#include // for equal +#include // for sequence +#include "../../../src/common/cuda_context.cuh" #include "../../../src/common/linalg_op.cuh" #include "../helpers.h" #include "xgboost/context.h" @@ -85,4 +88,23 @@ void TestSlice() { TEST(Linalg, GPUElementWise) { TestElementWiseKernel(); } TEST(Linalg, GPUTensorView) { TestSlice(); } + +TEST(Linalg, GPUIter) { + auto ctx = MakeCUDACtx(1); + auto cuctx = ctx.CUDACtx(); + + dh::device_vector data(2 * 3 * 4); + thrust::sequence(cuctx->CTP(), data.begin(), data.end(), 1.0); + + auto t = MakeTensorView(&ctx, dh::ToSpan(data), 2, 3, 4); + static_assert(!std::is_const_v); + static_assert(!std::is_const_v); + + auto n = std::distance(linalg::tcbegin(t), linalg::tcend(t)); + ASSERT_EQ(n, t.Size()); + ASSERT_FALSE(t.Empty()); + + bool eq = thrust::equal(cuctx->CTP(), data.cbegin(), data.cend(), linalg::tcbegin(t)); + ASSERT_TRUE(eq); +} } // namespace xgboost::linalg diff --git a/tests/cpp/common/test_parameter.cc b/tests/cpp/common/test_parameter.cc index 5e8021a1e7ba..5288366f8831 100644 --- a/tests/cpp/common/test_parameter.cc +++ b/tests/cpp/common/test_parameter.cc @@ -97,4 +97,9 @@ TEST(XGBoostParameter, Update) { ASSERT_NEAR(p.f, 2.71828f, kRtEps); ASSERT_NEAR(p.d, 2.71828, kRtEps); // default } + + // Just in case dmlc's use of global memory has any impact in parameters. + UpdatableParam a, b; + a.UpdateAllowUnknown(xgboost::Args{{"f", "2.71828"}}); + ASSERT_NE(a.f, b.f); } diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc index f7170dd14d06..f89a9e36cfc4 100644 --- a/tests/cpp/common/test_quantile.cc +++ b/tests/cpp/common/test_quantile.cc @@ -1,13 +1,20 @@ /** - * Copyright 2020-2023 by XGBoost Contributors + * Copyright 2020-2024, XGBoost Contributors */ #include "test_quantile.h" #include +#include // for int64_t + +#include "../../../src/collective/allreduce.h" #include "../../../src/common/hist_util.h" #include "../../../src/data/adapter.h" #include "../../../src/data/simple_dmatrix.h" // SimpleDMatrix +#include "../collective/test_worker.h" // for TestDistributedGlobal +#if defined(XGBOOST_USE_FEDERATED) +#include "../plugin/federated/test_worker.h" +#endif // defined(XGBOOST_USE_FEDERATED) #include "xgboost/context.h" namespace xgboost::common { @@ -51,7 +58,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) { SimpleLCG lcg; SimpleRealUniformDistribution dist(3, 1000); std::generate(h_weights.begin(), h_weights.end(), [&]() { return dist(&lcg); }); - std::vector column_size(cols, rows); + std::vector column_size(cols, rows); bst_bin_t n_bins = 64; // Generate cuts for distributed environment. @@ -91,6 +98,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) { // Generate cuts for single node environment collective::Finalize(); + CHECK_EQ(collective::GetWorldSize(), 1); std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; }); m->Info().num_row_ = world * rows; @@ -146,7 +154,8 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) { template void TestDistributedQuantile(size_t const rows, size_t const cols) { auto constexpr kWorkers = 4; - RunWithInMemoryCommunicator(kWorkers, DoTestDistributedQuantile, rows, cols); + collective::TestDistributedGlobal( + kWorkers, [=] { DoTestDistributedQuantile(rows, cols); }, false); } } // anonymous namespace @@ -193,7 +202,7 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) { return dmat->SliceCol(world, rank); }()}; - std::vector column_size(cols, 0); + std::vector column_size(cols, 0); auto const slice_size = cols / world; auto const slice_start = slice_size * rank; auto const slice_end = (rank == world - 1) ? cols : slice_start + slice_size; @@ -273,7 +282,8 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) { template void TestColSplitQuantile(size_t rows, size_t cols) { auto constexpr kWorkers = 4; - RunWithInMemoryCommunicator(kWorkers, DoTestColSplitQuantile, rows, cols); + collective::TestDistributedGlobal(kWorkers, + [=] { DoTestColSplitQuantile(rows, cols); }); } } // anonymous namespace @@ -297,6 +307,7 @@ TEST(Quantile, ColSplitSorted) { TestColSplitQuantile(kRows, kCols); } +#if defined(XGBOOST_USE_FEDERATED) namespace { template void DoTestColSplitQuantileSecure() { @@ -319,7 +330,7 @@ void DoTestColSplitQuantileSecure() { return dmat->SliceCol(world, rank); }()}; - std::vector column_size(cols, 0); + std::vector column_size(cols, 0); auto const slice_size = cols / world; auto const slice_start = slice_size * rank; auto const slice_end = (rank == world - 1) ? cols : slice_start + slice_size; @@ -384,17 +395,14 @@ void DoTestColSplitQuantileSecure() { template void TestColSplitQuantileSecure() { auto constexpr kWorkers = 2; - RunWithInMemoryCommunicator(kWorkers, DoTestColSplitQuantileSecure); + collective::TestFederatedGlobal(kWorkers, [&] { DoTestColSplitQuantileSecure(); }); } } // anonymous namespace -TEST(Quantile, ColSplitSecure) { - TestColSplitQuantileSecure(); -} +TEST(Quantile, ColSplitSecure) { TestColSplitQuantileSecure(); } -TEST(Quantile, ColSplitSecureSorted) { - TestColSplitQuantileSecure(); -} +TEST(Quantile, ColSplitSecureSorted) { TestColSplitQuantileSecure(); } +#endif // defined(XGBOOST_USE_FEDERATED) namespace { void TestSameOnAllWorkers() { @@ -424,43 +432,56 @@ void TestSameOnAllWorkers() { cut_ptrs(cuts.Ptrs().size() * world, 0); std::vector cut_min_values(cuts.MinValues().size() * world, 0); - size_t value_size = cuts.Values().size(); - collective::Allreduce(&value_size, 1); - size_t ptr_size = cuts.Ptrs().size(); - collective::Allreduce(&ptr_size, 1); - CHECK_EQ(ptr_size, kCols + 1); - size_t min_value_size = cuts.MinValues().size(); - collective::Allreduce(&min_value_size, 1); - CHECK_EQ(min_value_size, kCols); - - size_t value_offset = value_size * rank; - std::copy(cuts.Values().begin(), cuts.Values().end(), - cut_values.begin() + value_offset); - size_t ptr_offset = ptr_size * rank; - std::copy(cuts.Ptrs().cbegin(), cuts.Ptrs().cend(), - cut_ptrs.begin() + ptr_offset); - size_t min_values_offset = min_value_size * rank; + std::int64_t value_size = cuts.Values().size(); + std::int64_t ptr_size = cuts.Ptrs().size(); + std::int64_t min_value_size = cuts.MinValues().size(); + + auto rc = collective::Success() << [&] { + return collective::Allreduce(&ctx, &value_size, collective::Op::kMax); + } << [&] { + return collective::Allreduce(&ctx, &ptr_size, collective::Op::kMax); + } << [&] { + return collective::Allreduce(&ctx, &min_value_size, collective::Op::kMax); + }; + collective::SafeColl(rc); + ASSERT_EQ(ptr_size, kCols + 1); + ASSERT_EQ(min_value_size, kCols); + + std::size_t value_offset = value_size * rank; + std::copy(cuts.Values().begin(), cuts.Values().end(), cut_values.begin() + value_offset); + std::size_t ptr_offset = ptr_size * rank; + std::copy(cuts.Ptrs().cbegin(), cuts.Ptrs().cend(), cut_ptrs.begin() + ptr_offset); + std::size_t min_values_offset = min_value_size * rank; std::copy(cuts.MinValues().cbegin(), cuts.MinValues().cend(), cut_min_values.begin() + min_values_offset); - collective::Allreduce(cut_values.data(), cut_values.size()); - collective::Allreduce(cut_ptrs.data(), cut_ptrs.size()); - collective::Allreduce(cut_min_values.data(), cut_min_values.size()); - - for (int32_t i = 0; i < world; i++) { - for (size_t j = 0; j < value_size; ++j) { + rc = std::move(rc) << [&] { + return collective::Allreduce(&ctx, linalg::MakeVec(cut_values.data(), cut_values.size()), + collective::Op::kSum); + } << [&] { + return collective::Allreduce(&ctx, linalg::MakeVec(cut_ptrs.data(), cut_ptrs.size()), + collective::Op::kSum); + } << [&] { + return collective::Allreduce( + &ctx, linalg::MakeVec(cut_min_values.data(), cut_min_values.size()), + collective::Op::kSum); + }; + collective::SafeColl(rc); + + for (std::int32_t i = 0; i < world; i++) { + for (std::int64_t j = 0; j < value_size; ++j) { size_t idx = i * value_size + j; - EXPECT_NEAR(cuts.Values().at(j), cut_values.at(idx), kRtEps); + ASSERT_NEAR(cuts.Values().at(j), cut_values.at(idx), kRtEps); } - for (size_t j = 0; j < ptr_size; ++j) { + for (std::int64_t j = 0; j < ptr_size; ++j) { size_t idx = i * ptr_size + j; EXPECT_EQ(cuts.Ptrs().at(j), cut_ptrs.at(idx)); } - for (size_t j = 0; j < min_value_size; ++j) { + for (std::int64_t j = 0; j < min_value_size; ++j) { size_t idx = i * min_value_size + j; - EXPECT_EQ(cuts.MinValues().at(j), cut_min_values.at(idx)); + ASSERT_EQ(cuts.MinValues().at(j), cut_min_values.at(idx)); } } }); @@ -469,6 +490,6 @@ void TestSameOnAllWorkers() { TEST(Quantile, SameOnAllWorkers) { auto constexpr kWorkers = 4; - RunWithInMemoryCommunicator(kWorkers, TestSameOnAllWorkers); + collective::TestDistributedGlobal(kWorkers, [] { TestSameOnAllWorkers(); }); } } // namespace xgboost::common diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index 26bd05524ded..80c9c5c71e5a 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -1,12 +1,13 @@ /** - * Copyright 2020-2023, XGBoost contributors + * Copyright 2020-2024, XGBoost contributors */ #include -#include "../../../src/collective/communicator-inl.cuh" +#include "../../../src/collective/allreduce.h" #include "../../../src/common/hist_util.cuh" #include "../../../src/common/quantile.cuh" #include "../../../src/data/device_adapter.cuh" // CupyAdapter +#include "../collective/test_worker.h" // for BaseMGPUTest #include "../helpers.h" #include "test_quantile.h" @@ -18,16 +19,16 @@ struct IsSorted { } }; } -namespace common { -class MGPUQuantileTest : public BaseMGPUTest {}; +namespace common { +class MGPUQuantileTest : public collective::BaseMGPUTest {}; TEST(GPUQuantile, Basic) { constexpr size_t kRows = 1000, kCols = 100, kBins = 256; HostDeviceVector ft; SketchContainer sketch(ft, kBins, kCols, kRows, FstCU()); dh::caching_device_vector entries; - dh::device_vector cuts_ptr(kCols+1); + dh::device_vector cuts_ptr(kCols+1); thrust::fill(cuts_ptr.begin(), cuts_ptr.end(), 0); // Push empty sketch.Push(dh::ToSpan(entries), dh::ToSpan(cuts_ptr), dh::ToSpan(cuts_ptr), 0); @@ -36,7 +37,8 @@ TEST(GPUQuantile, Basic) { void TestSketchUnique(float sparsity) { constexpr size_t kRows = 1000, kCols = 100; - RunWithSeedsAndBins(kRows, [kRows, kCols, sparsity](int32_t seed, size_t n_bins, MetaInfo const& info) { + RunWithSeedsAndBins(kRows, [kRows, kCols, sparsity](std::int32_t seed, bst_bin_t n_bins, + MetaInfo const& info) { HostDeviceVector ft; SketchContainer sketch(ft, n_bins, kCols, kRows, FstCU()); @@ -87,11 +89,11 @@ TEST(GPUQuantile, Unique) { // if with_error is true, the test tolerates floating point error void TestQuantileElemRank(DeviceOrd device, Span in, - Span d_columns_ptr, bool with_error = false) { + Span d_columns_ptr, bool with_error = false) { dh::safe_cuda(cudaSetDevice(device.ordinal)); std::vector h_in(in.size()); dh::CopyDeviceSpanToVector(&h_in, in); - std::vector h_columns_ptr(d_columns_ptr.size()); + std::vector h_columns_ptr(d_columns_ptr.size()); dh::CopyDeviceSpanToVector(&h_columns_ptr, d_columns_ptr); for (size_t i = 1; i < d_columns_ptr.size(); ++i) { @@ -121,7 +123,7 @@ void TestQuantileElemRank(DeviceOrd device, Span in, TEST(GPUQuantile, Prune) { constexpr size_t kRows = 1000, kCols = 100; - RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) { + RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) { HostDeviceVector ft; SketchContainer sketch(ft, n_bins, kCols, kRows, FstCU()); @@ -164,7 +166,7 @@ TEST(GPUQuantile, MergeEmpty) { std::vector entries_before(sketch_0.Data().size()); dh::CopyDeviceSpanToVector(&entries_before, sketch_0.Data()); - std::vector ptrs_before(sketch_0.ColumnsPtr().size()); + std::vector ptrs_before(sketch_0.ColumnsPtr().size()); dh::CopyDeviceSpanToVector(&ptrs_before, sketch_0.ColumnsPtr()); thrust::device_vector columns_ptr(kCols + 1); // Merge an empty sketch @@ -172,7 +174,7 @@ TEST(GPUQuantile, MergeEmpty) { std::vector entries_after(sketch_0.Data().size()); dh::CopyDeviceSpanToVector(&entries_after, sketch_0.Data()); - std::vector ptrs_after(sketch_0.ColumnsPtr().size()); + std::vector ptrs_after(sketch_0.ColumnsPtr().size()); dh::CopyDeviceSpanToVector(&ptrs_after, sketch_0.ColumnsPtr()); CHECK_EQ(entries_before.size(), entries_after.size()); @@ -190,7 +192,7 @@ TEST(GPUQuantile, MergeEmpty) { TEST(GPUQuantile, MergeBasic) { constexpr size_t kRows = 1000, kCols = 100; - RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const &info) { + RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) { HostDeviceVector ft; SketchContainer sketch_0(ft, n_bins, kCols, kRows, FstCU()); HostDeviceVector storage_0; @@ -222,7 +224,7 @@ TEST(GPUQuantile, MergeBasic) { } auto columns_ptr = sketch_0.ColumnsPtr(); - std::vector h_columns_ptr(columns_ptr.size()); + std::vector h_columns_ptr(columns_ptr.size()); dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr); ASSERT_EQ(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge); @@ -260,9 +262,9 @@ void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) { using Tuple = thrust::tuple; auto it = thrust::make_zip_iterator(tuple_it); thrust::transform(thrust::device, it, it + data_1.size(), data_1.data(), - [=] __device__(Tuple const &tuple) { + [=] XGBOOST_DEVICE(Tuple const& tuple) { auto i = thrust::get<0>(tuple); - if (thrust::get<0>(tuple) % 2 == 0) { + if (i % 2 == 0) { return 0.0f; } else { return thrust::get<1>(tuple); @@ -278,7 +280,7 @@ void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) { TestQuantileElemRank(FstCU(), sketch_0.Data(), sketch_0.ColumnsPtr()); auto columns_ptr = sketch_0.ColumnsPtr(); - std::vector h_columns_ptr(columns_ptr.size()); + std::vector h_columns_ptr(columns_ptr.size()); dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr); ASSERT_EQ(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge); @@ -306,7 +308,7 @@ TEST(GPUQuantile, MergeDuplicated) { TEST(GPUQuantile, MultiMerge) { constexpr size_t kRows = 20, kCols = 1; int32_t world = 2; - RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) { + RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) { // Set up single node version HostDeviceVector ft; SketchContainer sketch_on_single_node(ft, n_bins, kCols, kRows, FstCU()); @@ -368,16 +370,18 @@ namespace { void TestAllReduceBasic() { auto const world = collective::GetWorldSize(); constexpr size_t kRows = 1000, kCols = 100; - RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) { + RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) { auto const device = DeviceOrd::CUDA(GPUIDX); auto ctx = MakeCUDACtx(device.ordinal); - // Set up single node version; + /** + * Set up single node version. + */ HostDeviceVector ft({}, device); SketchContainer sketch_on_single_node(ft, n_bins, kCols, kRows, device); - size_t intermediate_num_cuts = std::min( - kRows * world, static_cast(n_bins * WQSketch::kFactor)); + size_t intermediate_num_cuts = + std::min(kRows * world, static_cast(n_bins * WQSketch::kFactor)); std::vector containers; for (auto rank = 0; rank < world; ++rank) { HostDeviceVector storage({}, device); @@ -388,21 +392,22 @@ void TestAllReduceBasic() { data::CupyAdapter adapter(interface_str); HostDeviceVector ft({}, device); containers.emplace_back(ft, n_bins, kCols, kRows, device); - AdapterDeviceSketch(adapter.Value(), n_bins, info, - std::numeric_limits::quiet_NaN(), + AdapterDeviceSketch(adapter.Value(), n_bins, info, std::numeric_limits::quiet_NaN(), &containers.back()); } - for (auto &sketch : containers) { + for (auto& sketch : containers) { sketch.Prune(intermediate_num_cuts); sketch_on_single_node.Merge(sketch.ColumnsPtr(), sketch.Data()); sketch_on_single_node.FixError(); } sketch_on_single_node.Unique(); - TestQuantileElemRank(device, sketch_on_single_node.Data(), - sketch_on_single_node.ColumnsPtr(), true); + TestQuantileElemRank(device, sketch_on_single_node.Data(), sketch_on_single_node.ColumnsPtr(), + true); - // Set up distributed version. We rely on using rank as seed to generate - // the exact same copy of data. + /** + * Set up distributed version. We rely on using rank as seed to generate + * the exact same copy of data. + */ auto rank = collective::GetRank(); SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, device); HostDeviceVector storage({}, device); @@ -411,22 +416,23 @@ void TestAllReduceBasic() { .Seed(rank + seed) .GenerateArrayInterface(&storage); data::CupyAdapter adapter(interface_str); - AdapterDeviceSketch(adapter.Value(), n_bins, info, - std::numeric_limits::quiet_NaN(), + AdapterDeviceSketch(adapter.Value(), n_bins, info, std::numeric_limits::quiet_NaN(), &sketch_distributed); + if (world == 1) { + auto n_samples_global = kRows * world; + intermediate_num_cuts = + std::min(n_samples_global, static_cast(n_bins * SketchContainer::kFactor)); + sketch_distributed.Prune(intermediate_num_cuts); + } sketch_distributed.AllReduce(&ctx, false); sketch_distributed.Unique(); - ASSERT_EQ(sketch_distributed.ColumnsPtr().size(), - sketch_on_single_node.ColumnsPtr().size()); - ASSERT_EQ(sketch_distributed.Data().size(), - sketch_on_single_node.Data().size()); + ASSERT_EQ(sketch_distributed.ColumnsPtr().size(), sketch_on_single_node.ColumnsPtr().size()); + ASSERT_EQ(sketch_distributed.Data().size(), sketch_on_single_node.Data().size()); - TestQuantileElemRank(device, sketch_distributed.Data(), - sketch_distributed.ColumnsPtr(), true); + TestQuantileElemRank(device, sketch_distributed.Data(), sketch_distributed.ColumnsPtr(), true); - std::vector single_node_data( - sketch_on_single_node.Data().size()); + std::vector single_node_data(sketch_on_single_node.Data().size()); dh::CopyDeviceSpanToVector(&single_node_data, sketch_on_single_node.Data()); std::vector distributed_data(sketch_distributed.Data().size()); @@ -444,7 +450,8 @@ void TestAllReduceBasic() { } // anonymous namespace TEST_F(MGPUQuantileTest, AllReduceBasic) { - DoTest(TestAllReduceBasic); + this->DoTest([] { TestAllReduceBasic(); }, true); + this->DoTest([] { TestAllReduceBasic(); }, false); } namespace { @@ -490,7 +497,8 @@ void TestColumnSplit(DMatrix* dmat) { TEST_F(MGPUQuantileTest, ColumnSplitBasic) { std::size_t constexpr kRows = 1000, kCols = 100; auto dmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(); - DoTest(TestColumnSplit, dmat.get()); + this->DoTest([&] { TestColumnSplit(dmat.get()); }, true); + this->DoTest([&] { TestColumnSplit(dmat.get()); }, false); } TEST_F(MGPUQuantileTest, ColumnSplitCategorical) { @@ -507,15 +515,15 @@ TEST_F(MGPUQuantileTest, ColumnSplitCategorical) { .Type(ft) .MaxCategory(13) .GenerateDMatrix(); - DoTest(TestColumnSplit, dmat.get()); + this->DoTest([&] { TestColumnSplit(dmat.get()); }, true); + this->DoTest([&] { TestColumnSplit(dmat.get()); }, false); } namespace { void TestSameOnAllWorkers() { auto world = collective::GetWorldSize(); constexpr size_t kRows = 1000, kCols = 100; - RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, - MetaInfo const &info) { + RunWithSeedsAndBins(kRows, [=](std::int32_t seed, bst_bin_t n_bins, MetaInfo const& info) { auto const rank = collective::GetRank(); auto const device = DeviceOrd::CUDA(GPUIDX); Context ctx = MakeCUDACtx(device.ordinal); @@ -536,7 +544,8 @@ void TestSameOnAllWorkers() { // Test for all workers having the same sketch. size_t n_data = sketch_distributed.Data().size(); - collective::Allreduce(&n_data, 1); + auto rc = collective::Allreduce(&ctx, linalg::MakeVec(&n_data, 1), collective::Op::kMax); + SafeColl(rc); ASSERT_EQ(n_data, sketch_distributed.Data().size()); size_t size_as_float = sketch_distributed.Data().size_bytes() / sizeof(float); @@ -549,9 +558,10 @@ void TestSameOnAllWorkers() { thrust::copy(thrust::device, local_data.data(), local_data.data() + local_data.size(), all_workers.begin() + local_data.size() * rank); - collective::AllReduce(device.ordinal, all_workers.data().get(), - all_workers.size()); - collective::Synchronize(device.ordinal); + rc = collective::Allreduce( + &ctx, linalg::MakeVec(all_workers.data().get(), all_workers.size(), ctx.Device()), + collective::Op::kSum); + SafeColl(rc); auto base_line = dh::ToSpan(all_workers).subspan(0, size_as_float); std::vector h_base_line(base_line.size()); @@ -573,7 +583,8 @@ void TestSameOnAllWorkers() { } // anonymous namespace TEST_F(MGPUQuantileTest, SameOnAllWorkers) { - DoTest(TestSameOnAllWorkers); + this->DoTest([] { TestSameOnAllWorkers(); }, true); + this->DoTest([] { TestSameOnAllWorkers(); }, false); } TEST(GPUQuantile, Push) { diff --git a/tests/cpp/common/test_quantile.h b/tests/cpp/common/test_quantile.h index d34c5e0e4e34..38ace76c4d13 100644 --- a/tests/cpp/common/test_quantile.h +++ b/tests/cpp/common/test_quantile.h @@ -1,21 +1,22 @@ +/** + * Copyright 2020-2024, XGBoost Contributors + */ #ifndef XGBOOST_TESTS_CPP_COMMON_TEST_QUANTILE_H_ #define XGBOOST_TESTS_CPP_COMMON_TEST_QUANTILE_H_ #include -#include #include #include "../helpers.h" -namespace xgboost { -namespace common { +namespace xgboost::common { template void RunWithSeedsAndBins(size_t rows, Fn fn) { std::vector seeds(2); SimpleLCG lcg; SimpleRealUniformDistribution dist(3, 1000); std::generate(seeds.begin(), seeds.end(), [&](){ return dist(&lcg); }); - std::vector bins(2); + std::vector bins(2); for (size_t i = 0; i < bins.size() - 1; ++i) { bins[i] = i * 35 + 2; } @@ -36,7 +37,6 @@ template void RunWithSeedsAndBins(size_t rows, Fn fn) { } } } -} // namespace common -} // namespace xgboost +} // namespace xgboost::common #endif // XGBOOST_TESTS_CPP_COMMON_TEST_QUANTILE_H_ diff --git a/tests/cpp/common/test_span.cu b/tests/cpp/common/test_span.cu index 85c952340659..9c2bdc65cd34 100644 --- a/tests/cpp/common/test_span.cu +++ b/tests/cpp/common/test_span.cu @@ -1,14 +1,15 @@ -/*! - * Copyright 2018 XGBoost contributors +/** + * Copyright 2018-2024, XGBoost contributors */ #include - -#include #include #include +#include +#include + +#include // for iota #include "../../../src/common/device_helpers.cuh" -#include #include "test_span.h" namespace xgboost { diff --git a/tests/cpp/common/test_threadpool.cc b/tests/cpp/common/test_threadpool.cc new file mode 100644 index 000000000000..bd54a9dedbe2 --- /dev/null +++ b/tests/cpp/common/test_threadpool.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2024, XGBoost Contributors + */ +#include + +#include // for size_t +#include // for int32_t +#include // for future +#include // for sleep_for, thread + +#include "../../../src/common/threadpool.h" + +namespace xgboost::common { +TEST(ThreadPool, Basic) { + std::int32_t n_threads = std::thread::hardware_concurrency(); + ThreadPool pool{n_threads}; + { + auto fut = pool.Submit([] { return 3; }); + ASSERT_EQ(fut.get(), 3); + } + { + auto fut = pool.Submit([] { return std::string{"ok"}; }); + ASSERT_EQ(fut.get(), "ok"); + } + { + std::vector> futures; + for (std::size_t i = 0; i < static_cast(n_threads) * 16; ++i) { + futures.emplace_back(pool.Submit([=] { + std::this_thread::sleep_for(std::chrono::milliseconds{10}); + return i; + })); + } + for (std::size_t i = 0; i < futures.size(); ++i) { + ASSERT_EQ(futures[i].get(), i); + } + } + { + std::vector> futures; + for (std::size_t i = 0; i < static_cast(n_threads) * 16; ++i) { + futures.emplace_back(pool.Submit([=] { + return i; + })); + } + for (std::size_t i = 0; i < futures.size(); ++i) { + ASSERT_EQ(futures[i].get(), i); + } + } +} +} // namespace xgboost::common diff --git a/tests/cpp/common/test_transform_range.cc b/tests/cpp/common/test_transform_range.cc index 24d0267b6154..4fc06f63907e 100644 --- a/tests/cpp/common/test_transform_range.cc +++ b/tests/cpp/common/test_transform_range.cc @@ -1,11 +1,12 @@ /** - * Copyright 2018-2023 by XGBoost Contributors + * Copyright 2018-2024, XGBoost Contributors */ #include #include -#include #include +#include +#include // for iota #include #include "../../../src/common/transform.h" diff --git a/tests/cpp/data/test_adapter.cc b/tests/cpp/data/test_adapter.cc index fa3ed61f6808..f34cfceed2f3 100644 --- a/tests/cpp/data/test_adapter.cc +++ b/tests/cpp/data/test_adapter.cc @@ -36,7 +36,7 @@ TEST(Adapter, CSRAdapter) { } TEST(Adapter, CSRArrayAdapter) { - HostDeviceVector indptr; + HostDeviceVector indptr; HostDeviceVector values; HostDeviceVector indices; size_t n_features = 100, n_samples = 10; @@ -155,7 +155,7 @@ TEST(Adapter, IteratorAdapter) { ASSERT_EQ(data->Info().num_row_, kRows); int num_batch = 0; for (auto const& batch : data->GetBatches()) { - ASSERT_EQ(batch.offset.HostVector(), std::vector({0, 2, 4, 5, 5, 7, 9, 10, 10})); + ASSERT_EQ(batch.offset.HostVector(), std::vector({0, 2, 4, 5, 5, 7, 9, 10, 10})); ++num_batch; } ASSERT_EQ(num_batch, 1); diff --git a/tests/cpp/data/test_array_interface.cu b/tests/cpp/data/test_array_interface.cu index 00b996fb9ffb..be8160c8a493 100644 --- a/tests/cpp/data/test_array_interface.cu +++ b/tests/cpp/data/test_array_interface.cu @@ -1,10 +1,11 @@ /** - * Copyright 2021-2023, XGBoost Contributors + * Copyright 2021-2024, XGBoost Contributors */ #include #include -#include "../helpers.h" + #include "../../../src/data/array_interface.h" +#include "../helpers.h" namespace xgboost { diff --git a/tests/cpp/data/test_data.cc b/tests/cpp/data/test_data.cc index 99cd72cc09a0..f9e34790d4a3 100644 --- a/tests/cpp/data/test_data.cc +++ b/tests/cpp/data/test_data.cc @@ -13,7 +13,7 @@ namespace xgboost { TEST(SparsePage, PushCSC) { - std::vector offset {0}; + std::vector offset {0}; std::vector data; SparsePage batch; batch.offset.HostVector() = offset; diff --git a/tests/cpp/data/test_device_adapter.cu b/tests/cpp/data/test_device_adapter.cu index 2190dbe5bceb..61cc9463c228 100644 --- a/tests/cpp/data/test_device_adapter.cu +++ b/tests/cpp/data/test_device_adapter.cu @@ -62,7 +62,7 @@ TEST(DeviceAdapter, GetRowCounts) { .Device(ctx.Device()) .GenerateArrayInterface(&storage); auto adapter = CupyAdapter{str_arr}; - HostDeviceVector offset(adapter.NumRows() + 1, 0); + HostDeviceVector offset(adapter.NumRows() + 1, 0); offset.SetDevice(ctx.Device()); auto rstride = GetRowCounts(adapter.Value(), offset.DeviceSpan(), ctx.Device(), std::numeric_limits::quiet_NaN()); diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc index 67c5b39a424f..837ca7768d34 100644 --- a/tests/cpp/data/test_metainfo.cc +++ b/tests/cpp/data/test_metainfo.cc @@ -1,15 +1,18 @@ -// Copyright 2016-2021 by Contributors +/** + * Copyright 2016-2024, XGBoost contributors + */ #include "test_metainfo.h" #include +#include #include #include #include -#include "../../../src/common/version.h" -#include "../filesystem.h" // dmlc::TemporaryDirectory -#include "../helpers.h" +#include "../collective/test_worker.h" // for TestDistributedGlobal +#include "../filesystem.h" // dmlc::TemporaryDirectory +#include "../helpers.h" // for GMockTHrow #include "xgboost/base.h" namespace xgboost { @@ -20,23 +23,22 @@ TEST(MetaInfo, GetSet) { double double2[2] = {1.0, 2.0}; EXPECT_EQ(info.labels.Size(), 0); - info.SetInfo(ctx, "label", double2, xgboost::DataType::kFloat32, 2); + info.SetInfo(ctx, "label", Make1dInterfaceTest(double2, 2)); EXPECT_EQ(info.labels.Size(), 2); float float2[2] = {1.0f, 2.0f}; - EXPECT_EQ(info.GetWeight(1), 1.0f) - << "When no weights are given, was expecting default value 1"; - info.SetInfo(ctx, "weight", float2, xgboost::DataType::kFloat32, 2); + EXPECT_EQ(info.GetWeight(1), 1.0f) << "When no weights are given, was expecting default value 1"; + info.SetInfo(ctx, "weight", Make1dInterfaceTest(float2, 2)); EXPECT_EQ(info.GetWeight(1), 2.0f); uint32_t uint32_t2[2] = {1U, 2U}; EXPECT_EQ(info.base_margin_.Size(), 0); - info.SetInfo(ctx, "base_margin", uint32_t2, xgboost::DataType::kUInt32, 2); + info.SetInfo(ctx, "base_margin", Make1dInterfaceTest(uint32_t2, 2)); EXPECT_EQ(info.base_margin_.Size(), 2); uint64_t uint64_t2[2] = {1U, 2U}; EXPECT_EQ(info.group_ptr_.size(), 0); - info.SetInfo(ctx, "group", uint64_t2, xgboost::DataType::kUInt64, 2); + info.SetInfo(ctx, "group", Make1dInterfaceTest(uint64_t2, 2)); ASSERT_EQ(info.group_ptr_.size(), 3); EXPECT_EQ(info.group_ptr_[2], 3); @@ -46,6 +48,8 @@ TEST(MetaInfo, GetSet) { TEST(MetaInfo, GetSetFeature) { xgboost::MetaInfo info; + ASSERT_THAT([&] { info.SetFeatureInfo("", nullptr, 0); }, + GMockThrow("Unknown feature info name")); EXPECT_THROW(info.SetFeatureInfo("", nullptr, 0), dmlc::Error); EXPECT_THROW(info.SetFeatureInfo("foo", nullptr, 0), dmlc::Error); EXPECT_NO_THROW(info.SetFeatureInfo("feature_name", nullptr, 0)); @@ -86,7 +90,8 @@ void VerifyGetSetFeatureColumnSplit() { std::transform(types.cbegin(), types.cend(), c_types.begin(), [](auto const &str) { return str.c_str(); }); info.num_col_ = kCols; - EXPECT_THROW(info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size()), dmlc::Error); + ASSERT_THAT([&] { info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size()); }, + GMockThrow("Length of feature_type must be equal to number of columns")); info.num_col_ = kCols * world_size; EXPECT_NO_THROW(info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size())); std::vector expected_type_names{u8"float", u8"c", u8"float", @@ -103,7 +108,8 @@ void VerifyGetSetFeatureColumnSplit() { std::transform(names.cbegin(), names.cend(), c_names.begin(), [](auto const &str) { return str.c_str(); }); info.num_col_ = kCols; - EXPECT_THROW(info.SetFeatureInfo(u8"feature_name", c_names.data(), c_names.size()), dmlc::Error); + ASSERT_THAT([&] { info.SetFeatureInfo(u8"feature_name", c_names.data(), c_names.size()); }, + GMockThrow("Length of feature_name must be equal to number of columns")); info.num_col_ = kCols * world_size; EXPECT_NO_THROW(info.SetFeatureInfo(u8"feature_name", c_names.data(), c_names.size())); std::vector expected_names{u8"0.feature0", u8"0.feature1", u8"1.feature0", @@ -113,8 +119,8 @@ void VerifyGetSetFeatureColumnSplit() { } // anonymous namespace TEST(MetaInfo, GetSetFeatureColumnSplit) { - auto constexpr kWorldSize{3}; - RunWithInMemoryCommunicator(kWorldSize, VerifyGetSetFeatureColumnSplit); + auto constexpr kWorkers{3}; + collective::TestDistributedGlobal(kWorkers, VerifyGetSetFeatureColumnSplit); } TEST(MetaInfo, SaveLoadBinary) { @@ -128,9 +134,9 @@ TEST(MetaInfo, SaveLoadBinary) { }; std::vector values (kRows); std::generate(values.begin(), values.end(), generator); - info.SetInfo(ctx, "label", values.data(), xgboost::DataType::kFloat32, kRows); - info.SetInfo(ctx, "weight", values.data(), xgboost::DataType::kFloat32, kRows); - info.SetInfo(ctx, "base_margin", values.data(), xgboost::DataType::kFloat32, kRows); + info.SetInfo(ctx, "label", Make1dInterfaceTest(values.data(), kRows)); + info.SetInfo(ctx, "weight", Make1dInterfaceTest(values.data(), kRows)); + info.SetInfo(ctx, "base_margin", Make1dInterfaceTest(values.data(), kRows)); info.num_row_ = kRows; info.num_col_ = kCols; @@ -224,7 +230,7 @@ TEST(MetaInfo, LoadQid) { const std::vector expected_group_ptr{0, 4, 8, 12}; CHECK(info.group_ptr_ == expected_group_ptr); - const std::vector expected_offset{ + const std::vector expected_offset{ 0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60 }; const std::vector expected_data{ @@ -264,7 +270,7 @@ TEST(MetaInfo, CPUQid) { qid[i] = i; } - info.SetInfo(ctx, "qid", qid.data(), xgboost::DataType::kUInt32, info.num_row_); + info.SetInfo(ctx, "qid", Make1dInterfaceTest(qid.data(), info.num_row_)); ASSERT_EQ(info.group_ptr_.size(), info.num_row_ + 1); ASSERT_EQ(info.group_ptr_.front(), 0); ASSERT_EQ(info.group_ptr_.back(), info.num_row_); @@ -281,14 +287,12 @@ TEST(MetaInfo, Validate) { info.num_col_ = 3; std::vector groups (11); Context ctx; - info.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, 11); + info.SetInfo(ctx, "group", Make1dInterfaceTest(groups.data(), groups.size())); EXPECT_THROW(info.Validate(FstCU()), dmlc::Error); std::vector labels(info.num_row_ + 1); EXPECT_THROW( - { - info.SetInfo(ctx, "label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1); - }, + { info.SetInfo(ctx, "label", Make1dInterfaceTest(labels.data(), info.num_row_ + 1)); }, dmlc::Error); // Make overflow data, which can happen when users pass group structure as int @@ -298,13 +302,13 @@ TEST(MetaInfo, Validate) { groups.push_back(1562500); } groups.push_back(static_cast(-1)); - EXPECT_THROW(info.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, groups.size()), + EXPECT_THROW(info.SetInfo(ctx, "group", Make1dInterfaceTest(groups.data(), groups.size())), dmlc::Error); #if defined(XGBOOST_USE_CUDA) info.group_ptr_.clear(); labels.resize(info.num_row_); - info.SetInfo(ctx, "label", labels.data(), xgboost::DataType::kFloat32, info.num_row_); + info.SetInfo(ctx, "label", Make1dInterfaceTest(labels.data(), info.num_row_)); info.labels.SetDevice(FstCU()); EXPECT_THROW(info.Validate(DeviceOrd::CUDA(1)), dmlc::Error); @@ -333,8 +337,8 @@ TEST(MetaInfo, HostExtend) { for (size_t g = 0; g < kRows / per_group; ++g) { groups.emplace_back(per_group); } - lhs.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, groups.size()); - rhs.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, groups.size()); + lhs.SetInfo(ctx, "group", Make1dInterfaceTest(groups.data(), groups.size())); + rhs.SetInfo(ctx, "group", Make1dInterfaceTest(groups.data(), groups.size())); lhs.Extend(rhs, true, true); ASSERT_EQ(lhs.num_row_, kRows * 2); diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc index fa4165796bd9..ea6eedbb2e7b 100644 --- a/tests/cpp/data/test_simple_dmatrix.cc +++ b/tests/cpp/data/test_simple_dmatrix.cc @@ -9,6 +9,7 @@ #include "../../../src/data/adapter.h" // ArrayAdapter #include "../../../src/data/simple_dmatrix.h" // SimpleDMatrix +#include "../collective/test_worker.h" // for TestDistributedGlobal #include "../filesystem.h" // dmlc::TemporaryDirectory #include "../helpers.h" // RandomDataGenerator,CreateSimpleTestData #include "xgboost/base.h" @@ -223,7 +224,7 @@ TEST(SimpleDMatrix, FromFile) { auto batch = page.GetView(); EXPECT_EQ(batch.Size(), kExpectedNumRow); EXPECT_EQ(page.offset.HostVector(), - std::vector({0, 3, 6, 9, 12, 15, 15})); + std::vector({0, 3, 6, 9, 12, 15, 15})); EXPECT_EQ(page.base_rowid, 0); for (auto i = 0ull; i < batch.Size() - 1; i++) { @@ -444,5 +445,5 @@ void VerifyColumnSplit() { TEST(SimpleDMatrix, ColumnSplit) { auto constexpr kWorldSize{3}; - RunWithInMemoryCommunicator(kWorldSize, VerifyColumnSplit); + collective::TestDistributedGlobal(kWorldSize, VerifyColumnSplit); } diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index dac1f1cf7458..dcb89b97189c 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -171,7 +171,7 @@ TEST(GBTree, ChoosePredictor) { } TEST(GBTree, ChooseTreeMethod) { - bst_row_t n_samples{128}; + bst_idx_t n_samples{128}; bst_feature_t n_features{64}; auto Xy = RandomDataGenerator{n_samples, n_features, 0.5f}.GenerateDMatrix(true); @@ -408,7 +408,7 @@ class Dart : public testing::TestWithParam { for (size_t i = 0; i < kRows; ++i) { labels[i] = i % 2; } - p_mat->SetInfo("label", labels.data(), DataType::kFloat32, kRows); + p_mat->SetInfo("label", Make1dInterfaceTest(labels.data(), kRows)); auto learner = std::unique_ptr(Learner::Create({p_mat})); learner->SetParam("booster", "dart"); diff --git a/tests/cpp/gbm/test_gbtree.cu b/tests/cpp/gbm/test_gbtree.cu index f308e3b3ea36..227e07ffd3fd 100644 --- a/tests/cpp/gbm/test_gbtree.cu +++ b/tests/cpp/gbm/test_gbtree.cu @@ -18,7 +18,7 @@ namespace xgboost { void TestInplaceFallback(Context const* ctx) { // prepare data - bst_row_t n_samples{1024}; + bst_idx_t n_samples{1024}; bst_feature_t n_features{32}; HostDeviceVector X_storage; // use a different device than the learner diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 6ce362f46763..a4761063688d 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -12,9 +12,9 @@ #include #include -#include #include +#include "../../src/collective/communicator-inl.h" // for GetRank #include "../../src/data/adapter.h" #include "../../src/data/iterative_dmatrix.h" #include "../../src/data/simple_dmatrix.h" @@ -216,7 +216,7 @@ SimpleLCG::StateType SimpleLCG::Max() const { return max(); } static_assert(SimpleLCG::max() - SimpleLCG::min()); void RandomDataGenerator::GenerateLabels(std::shared_ptr p_fmat) const { - RandomDataGenerator{static_cast(p_fmat->Info().num_row_), this->n_targets_, 0.0f}.GenerateDense( + RandomDataGenerator{static_cast(p_fmat->Info().num_row_), this->n_targets_, 0.0f}.GenerateDense( p_fmat->Info().labels.Data()); CHECK_EQ(p_fmat->Info().labels.Size(), this->rows_ * this->n_targets_); p_fmat->Info().labels.Reshape(this->rows_, this->n_targets_); @@ -334,7 +334,7 @@ std::string RandomDataGenerator::GenerateColumnarArrayInterface( } void RandomDataGenerator::GenerateCSR( - HostDeviceVector* value, HostDeviceVector* row_ptr, + HostDeviceVector* value, HostDeviceVector* row_ptr, HostDeviceVector* columns) const { auto& h_value = value->HostVector(); auto& h_rptr = row_ptr->HostVector(); @@ -381,7 +381,7 @@ void RandomDataGenerator::GenerateCSR( [[nodiscard]] std::shared_ptr RandomDataGenerator::GenerateDMatrix( bool with_label, bool float_label, size_t classes, DataSplitMode data_split_mode) const { HostDeviceVector data; - HostDeviceVector rptrs; + HostDeviceVector rptrs; HostDeviceVector columns; this->GenerateCSR(&data, &rptrs, &columns); data::CSRAdapter adapter(rptrs.HostPointer(), columns.HostPointer(), data.HostPointer(), rows_, @@ -447,7 +447,7 @@ void RandomDataGenerator::GenerateCSR( // Loop over the batches and count the number of pages std::size_t batch_count = 0; - bst_row_t row_count = 0; + bst_idx_t row_count = 0; for (const auto& batch : dmat->GetBatches()) { batch_count++; row_count += batch.Size(); @@ -458,7 +458,7 @@ void RandomDataGenerator::GenerateCSR( EXPECT_EQ(row_count, dmat->Info().num_row_); if (with_label) { - RandomDataGenerator{static_cast(dmat->Info().num_row_), this->n_targets_, 0.0f}.GenerateDense( + RandomDataGenerator{static_cast(dmat->Info().num_row_), this->n_targets_, 0.0f}.GenerateDense( dmat->Info().labels.Data()); CHECK_EQ(dmat->Info().labels.Size(), this->rows_ * this->n_targets_); dmat->Info().labels.Reshape(this->rows_, this->n_targets_); @@ -488,7 +488,7 @@ int CudaArrayIterForTest::Next() { } #endif // !defined(XGBOOST_USE_CUDA) -NumpyArrayIterForTest::NumpyArrayIterForTest(float sparsity, size_t rows, size_t cols, +NumpyArrayIterForTest::NumpyArrayIterForTest(float sparsity, bst_idx_t rows, size_t cols, size_t batches) : ArrayIterForTest{sparsity, rows, cols, batches} { rng_->Device(DeviceOrd::CPU()); @@ -515,7 +515,7 @@ std::shared_ptr GetDMatrixFromData(const std::vector& x, std::si return p_fmat; } -std::unique_ptr CreateSparsePageDMatrix(bst_row_t n_samples, bst_feature_t n_features, +std::unique_ptr CreateSparsePageDMatrix(bst_idx_t n_samples, bst_feature_t n_features, size_t n_batches, std::string prefix) { CHECK_GE(n_samples, n_batches); NumpyArrayIterForTest iter(0, n_samples, n_features, n_batches); @@ -662,7 +662,7 @@ std::unique_ptr CreateTrainedGBM(std::string name, Args kwargs, return gbm; } -ArrayIterForTest::ArrayIterForTest(float sparsity, size_t rows, size_t cols, size_t batches) +ArrayIterForTest::ArrayIterForTest(float sparsity, bst_idx_t rows, size_t cols, size_t batches) : rows_{rows}, cols_{cols}, n_batches_{batches} { XGProxyDMatrixCreate(&proxy_); rng_ = std::make_unique(rows_, cols_, sparsity); diff --git a/tests/cpp/helpers.cu b/tests/cpp/helpers.cu index db94da27a9b9..f756289538ab 100644 --- a/tests/cpp/helpers.cu +++ b/tests/cpp/helpers.cu @@ -1,8 +1,11 @@ +/** + * Copyright 2020-2024, XGBoost contributors + */ #include -#include "helpers.h" #include "../../src/data/device_adapter.cuh" #include "../../src/data/iterative_dmatrix.h" +#include "helpers.h" namespace xgboost { diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index d603685eb073..cb8852e1b79f 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -1,8 +1,9 @@ /** - * Copyright 2016-2024 by XGBoost contributors + * Copyright 2016-2024, XGBoost contributors */ #pragma once +#include #include #include #include @@ -12,21 +13,22 @@ #include // for LearnerModelParam #include // for Configurable -#include // std::int32_t +#include // std::int32_t #include -#include -#include #include #include -#include #include -#include "../../src/collective/communicator-inl.h" -#include "../../src/common/common.h" -#include "../../src/common/threading_utils.h" -#include "../../src/data/array_interface.h" +#if defined(__CUDACC__) +#include "../../src/collective/communicator-inl.h" // for GetRank +#include "../../src/common/common.h" // for AllVisibleGPUs +#endif // defined(__CUDACC__) + #include "filesystem.h" // dmlc::TemporaryDirectory #include "xgboost/linalg.h" +#if !defined(_OPENMP) +#include +#endif #if defined(__CUDACC__) #define DeclareUnifiedTest(name) GPU ## name @@ -222,7 +224,7 @@ Json GetArrayInterface(HostDeviceVector const* storage, size_t rows, size_t c // Generate in-memory random data without using DMatrix. class RandomDataGenerator { - bst_row_t rows_; + bst_idx_t rows_; size_t cols_; float sparsity_; @@ -245,7 +247,7 @@ class RandomDataGenerator { void GenerateLabels(std::shared_ptr p_fmat) const; public: - RandomDataGenerator(bst_row_t rows, size_t cols, float sparsity) + RandomDataGenerator(bst_idx_t rows, size_t cols, float sparsity) : rows_{rows}, cols_{cols}, sparsity_{sparsity}, lcg_{seed_} {} RandomDataGenerator& Lower(float v) { @@ -307,7 +309,7 @@ class RandomDataGenerator { std::string GenerateColumnarArrayInterface(std::vector>* data) const; - void GenerateCSR(HostDeviceVector* value, HostDeviceVector* row_ptr, + void GenerateCSR(HostDeviceVector* value, HostDeviceVector* row_ptr, HostDeviceVector* columns) const; [[nodiscard]] std::shared_ptr GenerateDMatrix( @@ -332,7 +334,7 @@ inline std::vector GenerateRandomCategoricalSingleColumn(int n, size_t nu std::vector x(n); std::mt19937 rng(0); std::uniform_int_distribution dist(0, num_categories - 1); - std::generate(x.begin(), x.end(), [&]() { return dist(rng); }); + std::generate(x.begin(), x.end(), [&]() { return static_cast(dist(rng)); }); // Make sure each category is present for (size_t i = 0; i < num_categories; i++) { x[i] = static_cast(i); @@ -353,7 +355,7 @@ std::shared_ptr GetDMatrixFromData(const std::vector& x, std::si * * \return A Sparse DMatrix with n_batches. */ -std::unique_ptr CreateSparsePageDMatrix(bst_row_t n_samples, bst_feature_t n_features, +std::unique_ptr CreateSparsePageDMatrix(bst_idx_t n_samples, bst_feature_t n_features, size_t n_batches, std::string prefix = "cache"); /** @@ -412,12 +414,12 @@ inline HostDeviceVector GenerateRandomGradients(const size_t n_row return gpair; } -inline linalg::Matrix GenerateRandomGradients(Context const* ctx, bst_row_t n_rows, +inline linalg::Matrix GenerateRandomGradients(Context const* ctx, bst_idx_t n_rows, bst_target_t n_targets, float lower = 0.0f, float upper = 1.0f) { auto g = GenerateRandomGradients(n_rows * n_targets, lower, upper); - linalg::Matrix gpair({n_rows, static_cast(n_targets)}, ctx->Device()); + linalg::Matrix gpair({n_rows, static_cast(n_targets)}, ctx->Device()); gpair.Data()->Copy(g); return gpair; } @@ -433,12 +435,12 @@ class ArrayIterForTest { std::vector batches_; std::string interface_; - size_t rows_; + bst_idx_t rows_; size_t cols_; size_t n_batches_; public: - size_t static constexpr Rows() { return 1024; } + bst_idx_t static constexpr Rows() { return 1024; } size_t static constexpr Batches() { return 100; } size_t static constexpr Cols() { return 13; } @@ -450,7 +452,7 @@ class ArrayIterForTest { [[nodiscard]] std::size_t Iter() const { return iter_; } auto Proxy() -> decltype(proxy_) { return proxy_; } - explicit ArrayIterForTest(float sparsity, size_t rows, size_t cols, size_t batches); + explicit ArrayIterForTest(float sparsity, bst_idx_t rows, size_t cols, size_t batches); /** * \brief Create iterator with user provided data. */ @@ -469,7 +471,7 @@ class CudaArrayIterForTest : public ArrayIterForTest { class NumpyArrayIterForTest : public ArrayIterForTest { public: - explicit NumpyArrayIterForTest(float sparsity, size_t rows = Rows(), size_t cols = Cols(), + explicit NumpyArrayIterForTest(float sparsity, bst_idx_t rows = Rows(), size_t cols = Cols(), size_t batches = Batches()); explicit NumpyArrayIterForTest(Context const* ctx, HostDeviceVector const& data, std::size_t n_samples, bst_feature_t n_features, @@ -493,6 +495,16 @@ inline int Next(DataIterHandle self) { return static_cast(self)->Next(); } +/** + * @brief Create an array interface for host vector. + */ +template +char const* Make1dInterfaceTest(T const* vec, std::size_t len) { + static thread_local std::string str; + str = linalg::Make1dInterface(vec, len); + return str.c_str(); +} + class RMMAllocator; using RMMAllocatorPtr = std::unique_ptr; RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv); @@ -510,93 +522,9 @@ inline LearnerModelParam MakeMP(bst_feature_t n_features, float base_score, uint inline std::int32_t AllThreadsForTest() { return Context{}.Threads(); } -template -void RunWithInMemoryCommunicator(int32_t world_size, Function&& function, Args&&... args) { - auto run = [&](auto rank) { - Json config{JsonObject()}; - if constexpr (use_nccl) { - config["xgboost_communicator"] = String("in-memory-nccl"); - } else { - config["xgboost_communicator"] = String("in-memory"); - } - config["in_memory_world_size"] = world_size; - config["in_memory_rank"] = rank; - xgboost::collective::Init(config); - - std::forward(function)(std::forward(args)...); - - xgboost::collective::Finalize(); - }; -#if defined(_OPENMP) - common::ParallelFor(world_size, world_size, run); -#else - std::vector threads; - for (auto rank = 0; rank < world_size; rank++) { - threads.emplace_back(run, rank); - } - for (auto& thread : threads) { - thread.join(); - } -#endif -} - -class BaseMGPUTest : public ::testing::Test { - protected: - int world_size_; - bool use_nccl_{false}; - - void SetUp() override { - auto const n_gpus = common::AllVisibleGPUs(); - if (n_gpus <= 1) { - // Use a single GPU to simulate distributed environment. - world_size_ = 3; - // NCCL doesn't like sharing a single GPU, so we use the adapter instead. - use_nccl_ = false; - } else { - // Use multiple GPUs for real. - world_size_ = n_gpus; - use_nccl_ = true; - } - } - - template - void DoTest(Function&& function, Args&&... args) { - if (use_nccl_) { - RunWithInMemoryCommunicator(world_size_, function, args...); - } else { - RunWithInMemoryCommunicator(world_size_, function, args...); - } - } -}; - -class DeclareUnifiedDistributedTest(MetricTest) : public BaseMGPUTest{}; - inline DeviceOrd FstCU() { return DeviceOrd::CUDA(0); } -/** - * @brief poor man's gmock for message matching. - * - * @tparam Error The type of expected execption. - * - * @param submsg A substring of the actual error message. - * @param fn The function that throws Error - */ -template -void ExpectThrow(std::string submsg, Fn&& fn) { - try { - fn(); - } catch (Error const& exc) { - auto actual = std::string{exc.what()}; - ASSERT_NE(actual.find(submsg), std::string::npos) - << "Expecting substring `" << submsg << "` from the error message." - << " Got:\n" - << actual << "\n"; - return; - } catch (std::exception const& exc) { - auto actual = exc.what(); - ASSERT_TRUE(false) << "An unexpected type of exception is thrown. what:" << actual; - return; - } - ASSERT_TRUE(false) << "No exception is thrown"; +inline auto GMockThrow(StringView msg) { + return ::testing::ThrowsMessage(::testing::HasSubstr(msg)); } } // namespace xgboost diff --git a/tests/cpp/histogram_helpers.h b/tests/cpp/histogram_helpers.h index 496aa30f3475..8f345484d06b 100644 --- a/tests/cpp/histogram_helpers.h +++ b/tests/cpp/histogram_helpers.h @@ -47,7 +47,7 @@ inline std::unique_ptr BuildEllpackPage(int n_rows, int n_cols, 0.26f, 0.71f, 1.83f}); cmat.SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f}); - bst_row_t row_stride = 0; + bst_idx_t row_stride = 0; const auto &offset_vec = batch.offset.ConstHostVector(); for (size_t i = 1; i < offset_vec.size(); ++i) { row_stride = std::max(row_stride, offset_vec[i] - offset_vec[i-1]); diff --git a/tests/cpp/metric/test_auc.cc b/tests/cpp/metric/test_auc.cc deleted file mode 100644 index eea54fc3204a..000000000000 --- a/tests/cpp/metric/test_auc.cc +++ /dev/null @@ -1,68 +0,0 @@ -#include "test_auc.h" - -#include - -namespace xgboost { -namespace metric { - -TEST(Metric, DeclareUnifiedTest(BinaryAUC)) { VerifyBinaryAUC(); } - -TEST(Metric, DeclareUnifiedTest(MultiClassAUC)) { VerifyMultiClassAUC(); } - -TEST(Metric, DeclareUnifiedTest(RankingAUC)) { VerifyRankingAUC(); } - -TEST(Metric, DeclareUnifiedTest(PRAUC)) { VerifyPRAUC(); } - -TEST(Metric, DeclareUnifiedTest(MultiClassPRAUC)) { VerifyMultiClassPRAUC(); } - -TEST(Metric, DeclareUnifiedTest(RankingPRAUC)) { VerifyRankingPRAUC(); } - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), BinaryAUCRowSplit) { - DoTest(VerifyBinaryAUC, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), BinaryAUCColumnSplit) { - DoTest(VerifyBinaryAUC, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MultiClassAUCRowSplit) { - DoTest(VerifyMultiClassAUC, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MultiClassAUCColumnSplit) { - DoTest(VerifyMultiClassAUC, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), RankingAUCRowSplit) { - DoTest(VerifyRankingAUC, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), RankingAUCColumnSplit) { - DoTest(VerifyRankingAUC, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), PRAUCRowSplit) { - DoTest(VerifyPRAUC, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), PRAUCColumnSplit) { - DoTest(VerifyPRAUC, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MultiClassPRAUCRowSplit) { - DoTest(VerifyMultiClassPRAUC, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MultiClassPRAUCColumnSplit) { - DoTest(VerifyMultiClassPRAUC, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), RankingPRAUCRowSplit) { - DoTest(VerifyRankingPRAUC, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), RankingPRAUCColumnSplit) { - DoTest(VerifyRankingPRAUC, DataSplitMode::kCol); -} -} // namespace metric -} // namespace xgboost diff --git a/tests/cpp/metric/test_auc.cu b/tests/cpp/metric/test_auc.cu deleted file mode 100644 index 430ab1d374c1..000000000000 --- a/tests/cpp/metric/test_auc.cu +++ /dev/null @@ -1,5 +0,0 @@ -/*! - * Copyright 2021 XGBoost contributors - */ -// Dummy file to keep the CUDA conditional compile trick. -#include "test_auc.cc" \ No newline at end of file diff --git a/tests/cpp/metric/test_auc.h b/tests/cpp/metric/test_auc.h index cef6d9757d14..dc99ab2e95ef 100644 --- a/tests/cpp/metric/test_auc.h +++ b/tests/cpp/metric/test_auc.h @@ -7,11 +7,9 @@ #include "../helpers.h" -namespace xgboost { -namespace metric { - -inline void VerifyBinaryAUC(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +namespace xgboost::metric { +inline void VerifyBinaryAUC(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); std::unique_ptr uni_ptr{Metric::Create("auc", &ctx)}; Metric* metric = uni_ptr.get(); ASSERT_STREQ(metric->Name(), "auc"); @@ -53,8 +51,8 @@ inline void VerifyBinaryAUC(DataSplitMode data_split_mode = DataSplitMode::kRow) 0.5, 1e-10); } -inline void VerifyMultiClassAUC(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyMultiClassAUC(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); std::unique_ptr uni_ptr{Metric::Create("auc", &ctx)}; auto metric = uni_ptr.get(); @@ -114,8 +112,8 @@ inline void VerifyMultiClassAUC(DataSplitMode data_split_mode = DataSplitMode::k ASSERT_GT(auc, 0.714); } -inline void VerifyRankingAUC(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyRankingAUC(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); std::unique_ptr metric{Metric::Create("auc", &ctx)}; // single group @@ -148,8 +146,8 @@ inline void VerifyRankingAUC(DataSplitMode data_split_mode = DataSplitMode::kRow 0.769841f, 1e-6); } -inline void VerifyPRAUC(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyPRAUC(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); xgboost::Metric* metric = xgboost::Metric::Create("aucpr", &ctx); ASSERT_STREQ(metric->Name(), "aucpr"); @@ -185,8 +183,8 @@ inline void VerifyPRAUC(DataSplitMode data_split_mode = DataSplitMode::kRow) { delete metric; } -inline void VerifyMultiClassPRAUC(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyMultiClassPRAUC(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); std::unique_ptr metric{Metric::Create("aucpr", &ctx)}; @@ -209,8 +207,8 @@ inline void VerifyMultiClassPRAUC(DataSplitMode data_split_mode = DataSplitMode: ASSERT_GT(auc, 0.699); } -inline void VerifyRankingPRAUC(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyRankingPRAUC(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); std::unique_ptr metric{Metric::Create("aucpr", &ctx)}; @@ -245,5 +243,4 @@ inline void VerifyRankingPRAUC(DataSplitMode data_split_mode = DataSplitMode::kR data_split_mode), 0.556021f, 0.001f); } -} // namespace metric -} // namespace xgboost +} // namespace xgboost::metric diff --git a/tests/cpp/metric/test_distributed_metric.cc b/tests/cpp/metric/test_distributed_metric.cc new file mode 100644 index 000000000000..843ea5762f4b --- /dev/null +++ b/tests/cpp/metric/test_distributed_metric.cc @@ -0,0 +1,192 @@ +/** + * Copyright 2023, XGBoost contributors + */ +#include +#include // for DeviceOrd +#include // for DataSplitMode + +#include // for min +#include // for int32_t +#include // for function +#include // for string +#include // for thread + +#include "../collective/test_worker.h" // for TestDistributedGlobal +#include "test_auc.h" +#include "test_elementwise_metric.h" +#include "test_multiclass_metric.h" +#include "test_rank_metric.h" +#include "test_survival_metric.h" + +#if defined(XGBOOST_USE_FEDERATED) + +#include "../plugin/federated/test_worker.h" // for TestFederatedGlobal + +#endif // defined(XGBOOST_USE_FEDERATED) + +namespace xgboost::metric { +namespace { +using Verifier = std::function; +struct Param { + bool is_dist; // is distributed + bool is_fed; // is federated learning + DataSplitMode split; // how to split data + Verifier v; // test function + std::string name; // metric name + DeviceOrd device; // device to run +}; + +class TestDistributedMetric : public ::testing::TestWithParam { + protected: + template + void Run(bool is_dist, bool is_fed, DataSplitMode split_mode, Fn fn, DeviceOrd device) { + if (!is_dist) { + fn(split_mode, device); + return; + } + + std::int32_t n_workers{0}; + if (device.IsCUDA()) { + n_workers = common::AllVisibleGPUs(); + } else { + n_workers = std::min(static_cast(std::thread::hardware_concurrency()), 3); + } + auto fn1 = [&]() { + auto r = collective::GetRank(); + if (device.IsCPU()) { + fn(split_mode, DeviceOrd::CPU()); + } else { + fn(split_mode, DeviceOrd::CUDA(r)); + } + }; + if (is_fed) { +#if defined(XGBOOST_USE_FEDERATED) + collective::TestFederatedGlobal(n_workers, fn1); +#endif // defined(XGBOOST_USE_FEDERATED) + } else { + collective::TestDistributedGlobal(n_workers, fn1); + } + } +}; +} // anonymous namespace + +TEST_P(TestDistributedMetric, BinaryAUCRowSplit) { + auto p = GetParam(); + this->Run(p.is_dist, p.is_fed, p.split, p.v, p.device); +} + +constexpr bool UseNCCL() { +#if defined(XGBOOST_USE_NCCL) + return true; +#else + return false; +#endif // defined(XGBOOST_USE_NCCL) +} + +constexpr bool UseCUDA() { +#if defined(XGBOOST_USE_CUDA) + return true; +#else + return false; +#endif // defined(XGBOOST_USE_CUDA) +} + +constexpr bool UseFederated() { +#if defined(XGBOOST_USE_FEDERATED) + return true; +#else + return false; +#endif +} + +auto MakeParamsForTest() { + std::vector cases; + + auto push = [&](std::string name, auto fn) { + for (bool is_federated : {false, true}) { + for (DataSplitMode m : {DataSplitMode::kCol, DataSplitMode::kRow}) { + for (auto d : {DeviceOrd::CPU(), DeviceOrd::CUDA(0)}) { + if (!is_federated && !UseNCCL() && d.IsCUDA()) { + // Federated doesn't use nccl. + continue; + } + if (!UseCUDA() && d.IsCUDA()) { + // skip CUDA tests + continue; + } + if (!UseFederated() && is_federated) { + // skip GRPC tests + continue; + } + + auto p = Param{true, is_federated, m, fn, name, d}; + cases.push_back(p); + if (!is_federated) { + // Add a local test. + p.is_dist = false; + cases.push_back(p); + } + } + } + } + }; + +#define REFLECT_NAME(name) push(#name, Verify##name) + // AUC + REFLECT_NAME(BinaryAUC); + REFLECT_NAME(MultiClassAUC); + REFLECT_NAME(RankingAUC); + REFLECT_NAME(PRAUC); + REFLECT_NAME(MultiClassPRAUC); + REFLECT_NAME(RankingPRAUC); + // Elementwise + REFLECT_NAME(RMSE); + REFLECT_NAME(RMSLE); + REFLECT_NAME(MAE); + REFLECT_NAME(MAPE); + REFLECT_NAME(MPHE); + REFLECT_NAME(LogLoss); + REFLECT_NAME(Error); + REFLECT_NAME(PoissonNegLogLik); + REFLECT_NAME(MultiRMSE); + REFLECT_NAME(Quantile); + // Multi-Class + REFLECT_NAME(MultiClassError); + REFLECT_NAME(MultiClassLogLoss); + // Ranking + REFLECT_NAME(Precision); + REFLECT_NAME(NDCG); + REFLECT_NAME(MAP); + REFLECT_NAME(NDCGExpGain); + // AFT + using namespace xgboost::common; // NOLINT + REFLECT_NAME(AFTNegLogLik); + REFLECT_NAME(IntervalRegressionAccuracy); + +#undef REFLECT_NAME + + return cases; +} + +INSTANTIATE_TEST_SUITE_P( + DistributedMetric, TestDistributedMetric, ::testing::ValuesIn(MakeParamsForTest()), + [](const ::testing::TestParamInfo& info) { + std::string result; + if (info.param.is_dist) { + result += "Dist_"; + } + if (info.param.is_fed) { + result += "Federated_"; + } + if (info.param.split == DataSplitMode::kRow) { + result += "RowSplit"; + } else { + result += "ColSplit"; + } + result += "_"; + result += info.param.device.IsCPU() ? "CPU" : "CUDA"; + result += "_"; + result += info.param.name; + return result; + }); +} // namespace xgboost::metric diff --git a/tests/cpp/metric/test_elementwise_metric.cc b/tests/cpp/metric/test_elementwise_metric.cc deleted file mode 100644 index 11854ce8895b..000000000000 --- a/tests/cpp/metric/test_elementwise_metric.cc +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Copyright 2018-2023 by XGBoost contributors - */ -#include "test_elementwise_metric.h" - -namespace xgboost::metric { -TEST(Metric, DeclareUnifiedTest(RMSE)) { VerifyRMSE(); } - -TEST(Metric, DeclareUnifiedTest(RMSLE)) { VerifyRMSLE(); } - -TEST(Metric, DeclareUnifiedTest(MAE)) { VerifyMAE(); } - -TEST(Metric, DeclareUnifiedTest(MAPE)) { VerifyMAPE(); } - -TEST(Metric, DeclareUnifiedTest(MPHE)) { VerifyMPHE(); } - -TEST(Metric, DeclareUnifiedTest(LogLoss)) { VerifyLogLoss(); } - -TEST(Metric, DeclareUnifiedTest(Error)) { VerifyError(); } - -TEST(Metric, DeclareUnifiedTest(PoissonNegLogLik)) { VerifyPoissonNegLogLik(); } - -TEST(Metric, DeclareUnifiedTest(MultiRMSE)) { VerifyMultiRMSE(); } - -TEST(Metric, DeclareUnifiedTest(Quantile)) { VerifyQuantile(); } - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), RMSERowSplit) { - DoTest(VerifyRMSE, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), RMSEColumnSplit) { - DoTest(VerifyRMSE, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), RMSLERowSplit) { - DoTest(VerifyRMSLE, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), RMSLEColumnSplit) { - DoTest(VerifyRMSLE, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MAERowSplit) { - DoTest(VerifyMAE, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MAEColumnSplit) { - DoTest(VerifyMAE, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MAPERowSplit) { - DoTest(VerifyMAPE, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MAPEColumnSplit) { - DoTest(VerifyMAPE, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MPHERowSplit) { - DoTest(VerifyMPHE, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MPHEColumnSplit) { - DoTest(VerifyMPHE, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), LogLossRowSplit) { - DoTest(VerifyLogLoss, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), LogLossColumnSplit) { - DoTest(VerifyLogLoss, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), ErrorRowSplit) { - DoTest(VerifyError, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), ErrorColumnSplit) { - DoTest(VerifyError, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), PoissonNegLogLikRowSplit) { - DoTest(VerifyPoissonNegLogLik, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), PoissonNegLogLikColumnSplit) { - DoTest(VerifyPoissonNegLogLik, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MultiRMSERowSplit) { - DoTest(VerifyMultiRMSE, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MultiRMSEColumnSplit) { - DoTest(VerifyMultiRMSE, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), QuantileRowSplit) { - DoTest(VerifyQuantile, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), QuantileColumnSplit) { - DoTest(VerifyQuantile, DataSplitMode::kCol); -} -} // namespace xgboost::metric diff --git a/tests/cpp/metric/test_elementwise_metric.cu b/tests/cpp/metric/test_elementwise_metric.cu deleted file mode 100644 index c45db8f7ffc5..000000000000 --- a/tests/cpp/metric/test_elementwise_metric.cu +++ /dev/null @@ -1,5 +0,0 @@ -/*! - * Copyright 2018 XGBoost contributors - */ -// Dummy file to keep the CUDA conditional compile trick. -#include "test_elementwise_metric.cc" \ No newline at end of file diff --git a/tests/cpp/metric/test_elementwise_metric.h b/tests/cpp/metric/test_elementwise_metric.h index ef34d765144b..70a106798698 100644 --- a/tests/cpp/metric/test_elementwise_metric.h +++ b/tests/cpp/metric/test_elementwise_metric.h @@ -5,10 +5,9 @@ #include #include -#include #include +#include // for iota -#include "../../../src/common/linalg_op.h" #include "../helpers.h" namespace xgboost::metric { @@ -43,8 +42,8 @@ inline void CheckDeterministicMetricElementWise(StringView name, int32_t device) } } -inline void VerifyRMSE(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyRMSE(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); xgboost::Metric * metric = xgboost::Metric::Create("rmse", &ctx); metric->Configure({}); ASSERT_STREQ(metric->Name(), "rmse"); @@ -69,11 +68,11 @@ inline void VerifyRMSE(DataSplitMode data_split_mode = DataSplitMode::kRow) { 0.6708f, 0.001f); delete metric; - CheckDeterministicMetricElementWise(StringView{"rmse"}, GPUIDX); + CheckDeterministicMetricElementWise(StringView{"rmse"}, device.ordinal); } -inline void VerifyRMSLE(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyRMSLE(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); xgboost::Metric * metric = xgboost::Metric::Create("rmsle", &ctx); metric->Configure({}); ASSERT_STREQ(metric->Name(), "rmsle"); @@ -98,11 +97,11 @@ inline void VerifyRMSLE(DataSplitMode data_split_mode = DataSplitMode::kRow) { 0.2415f, 1e-4); delete metric; - CheckDeterministicMetricElementWise(StringView{"rmsle"}, GPUIDX); + CheckDeterministicMetricElementWise(StringView{"rmsle"}, device.ordinal); } -inline void VerifyMAE(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyMAE(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); xgboost::Metric * metric = xgboost::Metric::Create("mae", &ctx); metric->Configure({}); ASSERT_STREQ(metric->Name(), "mae"); @@ -127,11 +126,11 @@ inline void VerifyMAE(DataSplitMode data_split_mode = DataSplitMode::kRow) { 0.54f, 0.001f); delete metric; - CheckDeterministicMetricElementWise(StringView{"mae"}, GPUIDX); + CheckDeterministicMetricElementWise(StringView{"mae"}, device.ordinal); } -inline void VerifyMAPE(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyMAPE(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); xgboost::Metric * metric = xgboost::Metric::Create("mape", &ctx); metric->Configure({}); ASSERT_STREQ(metric->Name(), "mape"); @@ -156,11 +155,11 @@ inline void VerifyMAPE(DataSplitMode data_split_mode = DataSplitMode::kRow) { 1.3250f, 0.001f); delete metric; - CheckDeterministicMetricElementWise(StringView{"mape"}, GPUIDX); + CheckDeterministicMetricElementWise(StringView{"mape"}, device.ordinal); } -inline void VerifyMPHE(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyMPHE(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); std::unique_ptr metric{xgboost::Metric::Create("mphe", &ctx)}; metric->Configure({}); ASSERT_STREQ(metric->Name(), "mphe"); @@ -184,7 +183,7 @@ inline void VerifyMPHE(DataSplitMode data_split_mode = DataSplitMode::kRow) { { 1, 2, 9, 8}, {}, data_split_mode), 0.1922f, 1e-4); - CheckDeterministicMetricElementWise(StringView{"mphe"}, GPUIDX); + CheckDeterministicMetricElementWise(StringView{"mphe"}, device.ordinal); metric->Configure({{"huber_slope", "0.1"}}); EXPECT_NEAR(GetMetricEval(metric.get(), @@ -194,8 +193,8 @@ inline void VerifyMPHE(DataSplitMode data_split_mode = DataSplitMode::kRow) { 0.0461686f, 1e-4); } -inline void VerifyLogLoss(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyLogLoss(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); xgboost::Metric * metric = xgboost::Metric::Create("logloss", &ctx); metric->Configure({}); ASSERT_STREQ(metric->Name(), "logloss"); @@ -224,11 +223,11 @@ inline void VerifyLogLoss(DataSplitMode data_split_mode = DataSplitMode::kRow) { 1.3138f, 0.001f); delete metric; - CheckDeterministicMetricElementWise(StringView{"logloss"}, GPUIDX); + CheckDeterministicMetricElementWise(StringView{"logloss"}, device.ordinal); } -inline void VerifyError(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyError(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); xgboost::Metric * metric = xgboost::Metric::Create("error", &ctx); metric->Configure({}); ASSERT_STREQ(metric->Name(), "error"); @@ -286,11 +285,11 @@ inline void VerifyError(DataSplitMode data_split_mode = DataSplitMode::kRow) { 0.45f, 0.001f); delete metric; - CheckDeterministicMetricElementWise(StringView{"error@0.5"}, GPUIDX); + CheckDeterministicMetricElementWise(StringView{"error@0.5"}, device.ordinal); } -inline void VerifyPoissonNegLogLik(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyPoissonNegLogLik(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); xgboost::Metric * metric = xgboost::Metric::Create("poisson-nloglik", &ctx); metric->Configure({}); ASSERT_STREQ(metric->Name(), "poisson-nloglik"); @@ -319,11 +318,11 @@ inline void VerifyPoissonNegLogLik(DataSplitMode data_split_mode = DataSplitMode 1.5783f, 0.001f); delete metric; - CheckDeterministicMetricElementWise(StringView{"poisson-nloglik"}, GPUIDX); + CheckDeterministicMetricElementWise(StringView{"poisson-nloglik"}, device.ordinal); } -inline void VerifyMultiRMSE(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyMultiRMSE(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); size_t n_samples = 32, n_targets = 8; linalg::Tensor y{{n_samples, n_targets}, ctx.Device()}; auto &h_y = y.Data()->HostVector(); @@ -344,8 +343,8 @@ inline void VerifyMultiRMSE(DataSplitMode data_split_mode = DataSplitMode::kRow) ASSERT_FLOAT_EQ(ret, loss_w); } -inline void VerifyQuantile(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyQuantile(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); std::unique_ptr metric{Metric::Create("quantile", &ctx)}; HostDeviceVector predts{0.1f, 0.9f, 0.1f, 0.9f}; diff --git a/tests/cpp/metric/test_multiclass_metric.cc b/tests/cpp/metric/test_multiclass_metric.cc deleted file mode 100644 index 7fc8bc42934b..000000000000 --- a/tests/cpp/metric/test_multiclass_metric.cc +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright by Contributors -#include "test_multiclass_metric.h" - -#include - -namespace xgboost { -namespace metric { - -TEST(Metric, DeclareUnifiedTest(MultiClassError)) { VerifyMultiClassError(); } - -TEST(Metric, DeclareUnifiedTest(MultiClassLogLoss)) { VerifyMultiClassLogLoss(); } - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MultiClassErrorRowSplit) { - DoTest(VerifyMultiClassError, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MultiClassErrorColumnSplit) { - DoTest(VerifyMultiClassError, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MultiClassLogLossRowSplit) { - DoTest(VerifyMultiClassLogLoss, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MultiClassLogLossColumnSplit) { - DoTest(VerifyMultiClassLogLoss, DataSplitMode::kCol); -} -} // namespace metric -} // namespace xgboost diff --git a/tests/cpp/metric/test_multiclass_metric.cu b/tests/cpp/metric/test_multiclass_metric.cu deleted file mode 100644 index 8a087565b3da..000000000000 --- a/tests/cpp/metric/test_multiclass_metric.cu +++ /dev/null @@ -1,5 +0,0 @@ -/*! - * Copyright 2019 XGBoost contributors - */ -// Dummy file to keep the CUDA conditional compile trick. -#include "test_multiclass_metric.cc" \ No newline at end of file diff --git a/tests/cpp/metric/test_multiclass_metric.h b/tests/cpp/metric/test_multiclass_metric.h index 5fdead596ad2..002e38cb1e21 100644 --- a/tests/cpp/metric/test_multiclass_metric.h +++ b/tests/cpp/metric/test_multiclass_metric.h @@ -44,8 +44,8 @@ inline void CheckDeterministicMetricMultiClass(StringView name, int32_t device) } } -inline void TestMultiClassError(int device, DataSplitMode data_split_mode) { - auto ctx = MakeCUDACtx(device); +inline void TestMultiClassError(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); xgboost::Metric * metric = xgboost::Metric::Create("merror", &ctx); metric->Configure({}); ASSERT_STREQ(metric->Name(), "merror"); @@ -59,13 +59,13 @@ inline void TestMultiClassError(int device, DataSplitMode data_split_mode) { delete metric; } -inline void VerifyMultiClassError(DataSplitMode data_split_mode = DataSplitMode::kRow) { - TestMultiClassError(GPUIDX, data_split_mode); - CheckDeterministicMetricMultiClass(StringView{"merror"}, GPUIDX); +inline void VerifyMultiClassError(DataSplitMode data_split_mode, DeviceOrd device) { + TestMultiClassError(data_split_mode, device); + CheckDeterministicMetricMultiClass(StringView{"merror"}, device.ordinal); } -inline void TestMultiClassLogLoss(int device, DataSplitMode data_split_mode) { - auto ctx = MakeCUDACtx(device); +inline void TestMultiClassLogLoss(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); xgboost::Metric * metric = xgboost::Metric::Create("mlogloss", &ctx); metric->Configure({}); ASSERT_STREQ(metric->Name(), "mlogloss"); @@ -80,9 +80,9 @@ inline void TestMultiClassLogLoss(int device, DataSplitMode data_split_mode) { delete metric; } -inline void VerifyMultiClassLogLoss(DataSplitMode data_split_mode = DataSplitMode::kRow) { - TestMultiClassLogLoss(GPUIDX, data_split_mode); - CheckDeterministicMetricMultiClass(StringView{"mlogloss"}, GPUIDX); +inline void VerifyMultiClassLogLoss(DataSplitMode data_split_mode, DeviceOrd device) { + TestMultiClassLogLoss(data_split_mode, device); + CheckDeterministicMetricMultiClass(StringView{"mlogloss"}, device.ordinal); } } // namespace metric diff --git a/tests/cpp/metric/test_rank_metric.cc b/tests/cpp/metric/test_rank_metric.cc index fbf0611b3f06..4c69847f8396 100644 --- a/tests/cpp/metric/test_rank_metric.cc +++ b/tests/cpp/metric/test_rank_metric.cc @@ -1,84 +1,29 @@ /** - * Copyright 2016-2023 by XGBoost Contributors + * Copyright 2016-2023, XGBoost Contributors */ -#include // for Test, EXPECT_NEAR, ASSERT_STREQ -#include // for Context -#include // for MetaInfo, DMatrix -#include // for Matrix -#include // for Metric +#include "test_rank_metric.h" -#include // for max -#include // for unique_ptr -#include // for vector +#include // for Test, EXPECT_NEAR, ASSERT_STREQ +#include // for Context +#include // for Metric -#include "test_rank_metric.h" -#include "../helpers.h" // for GetMetricEval, CreateEmptyGe... -#include "xgboost/base.h" // for bst_float, kRtEps -#include "xgboost/host_device_vector.h" // for HostDeviceVector -#include "xgboost/json.h" // for Json, String, Object +#include // for unique_ptr -namespace xgboost { -namespace metric { +#include "../helpers.h" // for GetMetricEval, CreateEmptyGe... +#include "xgboost/base.h" // for bst_float, kRtEps -#if !defined(__CUDACC__) +namespace xgboost::metric { TEST(Metric, AMS) { auto ctx = MakeCUDACtx(GPUIDX); EXPECT_ANY_THROW(Metric::Create("ams", &ctx)); - Metric* metric = Metric::Create("ams@0.5f", &ctx); + std::unique_ptr metric{Metric::Create("ams@0.5f", &ctx)}; ASSERT_STREQ(metric->Name(), "ams@0.5"); - EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0.311f, 0.001f); - EXPECT_NEAR(GetMetricEval(metric, - {0.1f, 0.9f, 0.1f, 0.9f}, - { 0, 0, 1, 1}), - 0.29710f, 0.001f); + EXPECT_NEAR(GetMetricEval(metric.get(), {0, 1}, {0, 1}), 0.311f, 0.001f); + EXPECT_NEAR(GetMetricEval(metric.get(), {0.1f, 0.9f, 0.1f, 0.9f}, {0, 0, 1, 1}), 0.29710f, + 0.001f); - delete metric; - metric = Metric::Create("ams@0", &ctx); + metric.reset(Metric::Create("ams@0", &ctx)); ASSERT_STREQ(metric->Name(), "ams@0"); - EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 0.311f, 0.001f); - - delete metric; -} -#endif - -TEST(Metric, DeclareUnifiedTest(Precision)) { VerifyPrecision(); } - -TEST(Metric, DeclareUnifiedTest(NDCG)) { VerifyNDCG(); } - -TEST(Metric, DeclareUnifiedTest(MAP)) { VerifyMAP(); } - -TEST(Metric, DeclareUnifiedTest(NDCGExpGain)) { VerifyNDCGExpGain(); } - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), PrecisionRowSplit) { - DoTest(VerifyPrecision, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), PrecisionColumnSplit) { - DoTest(VerifyPrecision, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), NDCGRowSplit) { - DoTest(VerifyNDCG, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), NDCGColumnSplit) { - DoTest(VerifyNDCG, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MAPRowSplit) { - DoTest(VerifyMAP, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), MAPColumnSplit) { - DoTest(VerifyMAP, DataSplitMode::kCol); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), NDCGExpGainRowSplit) { - DoTest(VerifyNDCGExpGain, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), NDCGExpGainColumnSplit) { - DoTest(VerifyNDCGExpGain, DataSplitMode::kCol); + EXPECT_NEAR(GetMetricEval(metric.get(), {0, 1}, {0, 1}), 0.311f, 0.001f); } -} // namespace metric -} // namespace xgboost +} // namespace xgboost::metric diff --git a/tests/cpp/metric/test_rank_metric.cu b/tests/cpp/metric/test_rank_metric.cu deleted file mode 100644 index 38b4c72e196a..000000000000 --- a/tests/cpp/metric/test_rank_metric.cu +++ /dev/null @@ -1,5 +0,0 @@ -/*! - * Copyright 2019 XGBoost contributors - */ -// Dummy file to keep the CUDA conditional compile trick. -#include "test_rank_metric.cc" diff --git a/tests/cpp/metric/test_rank_metric.h b/tests/cpp/metric/test_rank_metric.h index 5d5e87072937..bb409628868a 100644 --- a/tests/cpp/metric/test_rank_metric.h +++ b/tests/cpp/metric/test_rank_metric.h @@ -19,8 +19,8 @@ namespace xgboost::metric { -inline void VerifyPrecision(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyPrecision(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); std::unique_ptr metric{Metric::Create("pre", &ctx)}; ASSERT_STREQ(metric->Name(), "pre"); EXPECT_NEAR(GetMetricEval(metric.get(), {0, 1}, {0, 1}, {}, {}, data_split_mode), 0.5, 1e-7); @@ -43,8 +43,8 @@ inline void VerifyPrecision(DataSplitMode data_split_mode = DataSplitMode::kRow) 0.5f, 1e-7); } -inline void VerifyNDCG(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyNDCG(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); Metric * metric = xgboost::Metric::Create("ndcg", &ctx); ASSERT_STREQ(metric->Name(), "ndcg"); EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {}, {}, {}, data_split_mode)); @@ -101,8 +101,8 @@ inline void VerifyNDCG(DataSplitMode data_split_mode = DataSplitMode::kRow) { delete metric; } -inline void VerifyMAP(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyMAP(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); Metric * metric = xgboost::Metric::Create("map", &ctx); ASSERT_STREQ(metric->Name(), "map"); EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}, {}, {}, data_split_mode), 1, kRtEps); @@ -149,8 +149,8 @@ inline void VerifyMAP(DataSplitMode data_split_mode = DataSplitMode::kRow) { delete metric; } -inline void VerifyNDCGExpGain(DataSplitMode data_split_mode = DataSplitMode::kRow) { - Context ctx = MakeCUDACtx(GPUIDX); +inline void VerifyNDCGExpGain(DataSplitMode data_split_mode, DeviceOrd device) { + Context ctx = MakeCUDACtx(device.ordinal); auto p_fmat = xgboost::RandomDataGenerator{0, 0, 0}.GenerateDMatrix(); MetaInfo& info = p_fmat->Info(); diff --git a/tests/cpp/metric/test_survival_metric.cc b/tests/cpp/metric/test_survival_metric.cc index ded9c4b0e2de..1b02fd7e7b02 100644 --- a/tests/cpp/metric/test_survival_metric.cc +++ b/tests/cpp/metric/test_survival_metric.cc @@ -1,5 +1,5 @@ -/*! - * Copyright (c) by Contributors 2020 +/** + * Copyright 2020-2023, XGBoost Contributors */ #include #include @@ -16,8 +16,7 @@ // CUDA conditional compile trick. #include "test_survival_metric.cu" -namespace xgboost { -namespace common { +namespace xgboost::common { /** Tests for Survival metrics that should run only on CPU **/ @@ -113,6 +112,4 @@ TEST(AFTLoss, IntervalCensored) { { 8.0000, 4.8004, 2.8805, 1.7284, 1.0372, 0.6231, 0.3872, 0.3031, 0.3740, 0.5839, 0.8995, 1.2878, 1.7231, 2.1878, 2.6707, 3.1647, 3.6653, 4.1699, 4.6770, 5.1856 }); } - -} // namespace common -} // namespace xgboost +} // namespace xgboost::common diff --git a/tests/cpp/metric/test_survival_metric.cu b/tests/cpp/metric/test_survival_metric.cu index eec92dc990a8..ead8d11f2a3b 100644 --- a/tests/cpp/metric/test_survival_metric.cu +++ b/tests/cpp/metric/test_survival_metric.cu @@ -7,28 +7,7 @@ /** Tests for Survival metrics that should run both on CPU and GPU **/ -namespace xgboost { -namespace common { -TEST(Metric, DeclareUnifiedTest(AFTNegLogLik)) { VerifyAFTNegLogLik(); } - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), AFTNegLogLikRowSplit) { - DoTest(VerifyAFTNegLogLik, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), AFTNegLogLikColumnSplit) { - DoTest(VerifyAFTNegLogLik, DataSplitMode::kCol); -} - -TEST(Metric, DeclareUnifiedTest(IntervalRegressionAccuracy)) { VerifyIntervalRegressionAccuracy(); } - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), IntervalRegressionAccuracyRowSplit) { - DoTest(VerifyIntervalRegressionAccuracy, DataSplitMode::kRow); -} - -TEST_F(DeclareUnifiedDistributedTest(MetricTest), IntervalRegressionAccuracyColumnSplit) { - DoTest(VerifyIntervalRegressionAccuracy, DataSplitMode::kCol); -} - +namespace xgboost::common { // Test configuration of AFT metric TEST(AFTNegLogLikMetric, DeclareUnifiedTest(Configuration)) { auto ctx = MakeCUDACtx(GPUIDX); @@ -44,5 +23,4 @@ TEST(AFTNegLogLikMetric, DeclareUnifiedTest(Configuration)) { CheckDeterministicMetricElementWise(StringView{"aft-nloglik"}, GPUIDX); } -} // namespace common -} // namespace xgboost +} // namespace xgboost::common diff --git a/tests/cpp/metric/test_survival_metric.h b/tests/cpp/metric/test_survival_metric.h index 1626d37724ed..902c9aa6bbcf 100644 --- a/tests/cpp/metric/test_survival_metric.h +++ b/tests/cpp/metric/test_survival_metric.h @@ -47,8 +47,8 @@ inline void CheckDeterministicMetricElementWise(StringView name, int32_t device) } } -inline void VerifyAFTNegLogLik(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyAFTNegLogLik(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); /** * Test aggregate output from the AFT metric over a small test data set. @@ -78,8 +78,8 @@ inline void VerifyAFTNegLogLik(DataSplitMode data_split_mode = DataSplitMode::kR } } -inline void VerifyIntervalRegressionAccuracy(DataSplitMode data_split_mode = DataSplitMode::kRow) { - auto ctx = MakeCUDACtx(GPUIDX); +inline void VerifyIntervalRegressionAccuracy(DataSplitMode data_split_mode, DeviceOrd device) { + auto ctx = MakeCUDACtx(device.ordinal); auto p_fmat = EmptyDMatrix(); MetaInfo& info = p_fmat->Info(); @@ -101,7 +101,7 @@ inline void VerifyIntervalRegressionAccuracy(DataSplitMode data_split_mode = Dat info.labels_lower_bound_.HostVector()[0] = 70.0f; EXPECT_FLOAT_EQ(metric->Evaluate(preds, p_fmat), 0.25f); - CheckDeterministicMetricElementWise(StringView{"interval-regression-accuracy"}, GPUIDX); + CheckDeterministicMetricElementWise(StringView{"interval-regression-accuracy"}, device.ordinal); } } // namespace common } // namespace xgboost diff --git a/tests/cpp/objective/test_objective.cc b/tests/cpp/objective/test_objective.cc index 21ffc7cafa37..efdd03612a0f 100644 --- a/tests/cpp/objective/test_objective.cc +++ b/tests/cpp/objective/test_objective.cc @@ -50,7 +50,7 @@ class TestDefaultObjConfig : public ::testing::TestWithParam { public: void Run(std::string objective) { - auto Xy = MakeFmatForObjTest(objective); + auto Xy = MakeFmatForObjTest(objective, 10, 10); std::unique_ptr learner{Learner::Create({Xy})}; std::unique_ptr objfn{ObjFunction::Create(objective, &ctx_)}; diff --git a/tests/cpp/objective/test_regression_obj_cpu.cc b/tests/cpp/objective/test_regression_obj_cpu.cc index 3613d0d901bc..18ee4db7eede 100644 --- a/tests/cpp/objective/test_regression_obj_cpu.cc +++ b/tests/cpp/objective/test_regression_obj_cpu.cc @@ -1,14 +1,15 @@ -/*! - * Copyright 2018-2023 XGBoost contributors +/** + * Copyright 2018-2024, XGBoost contributors */ #include #include #include +#include // for iota + #include "../../../src/objective/adaptive.h" -#include "../../../src/tree/param.h" // for TrainParam +#include "../../../src/tree/param.h" // for TrainParam #include "../helpers.h" - #include "test_regression_obj.h" namespace xgboost { diff --git a/tests/cpp/objective_helpers.cc b/tests/cpp/objective_helpers.cc index ed80f71d512f..9ad4b5c39688 100644 --- a/tests/cpp/objective_helpers.cc +++ b/tests/cpp/objective_helpers.cc @@ -1,5 +1,5 @@ /** - * Copyright (c) 2023, XGBoost contributors + * Copyright 2023-2024, XGBoost contributors */ #include "objective_helpers.h" @@ -7,17 +7,17 @@ #include "helpers.h" // for RandomDataGenerator namespace xgboost { -std::shared_ptr MakeFmatForObjTest(std::string const& obj) { - auto constexpr kRows = 10, kCols = 10; - auto p_fmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true); + +void MakeLabelForObjTest(std::shared_ptr p_fmat, std::string const& obj) { auto& h_upper = p_fmat->Info().labels_upper_bound_.HostVector(); auto& h_lower = p_fmat->Info().labels_lower_bound_.HostVector(); - h_lower.resize(kRows); - h_upper.resize(kRows); - for (size_t i = 0; i < kRows; ++i) { + h_lower.resize(p_fmat->Info().num_row_); + h_upper.resize(p_fmat->Info().num_row_); + for (size_t i = 0; i < p_fmat->Info().num_row_; ++i) { h_lower[i] = 1; h_upper[i] = 10; } + if (obj.find("rank:") != std::string::npos) { auto h_label = p_fmat->Info().labels.HostView(); std::size_t k = 0; @@ -26,6 +26,12 @@ std::shared_ptr MakeFmatForObjTest(std::string const& obj) { ++k; } } +} + +std::shared_ptr MakeFmatForObjTest(std::string const& obj, bst_idx_t n_samples, + bst_feature_t n_features) { + auto p_fmat = RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true); + MakeLabelForObjTest(p_fmat, obj); return p_fmat; }; } // namespace xgboost diff --git a/tests/cpp/objective_helpers.h b/tests/cpp/objective_helpers.h index 7f394ef8d523..972747c36e21 100644 --- a/tests/cpp/objective_helpers.h +++ b/tests/cpp/objective_helpers.h @@ -32,5 +32,11 @@ inline std::string ObjTestNameGenerator(const ::testing::TestParamInfo MakeFmatForObjTest(std::string const& obj); +/** + * @brief Construct random label for testing. + */ +void MakeLabelForObjTest(std::shared_ptr p_fmat, std::string const& obj); + +std::shared_ptr MakeFmatForObjTest(std::string const& obj, bst_idx_t n_samples, + bst_feature_t n_features); } // namespace xgboost diff --git a/tests/cpp/plugin/federated/test_federated_coll.cc b/tests/cpp/plugin/federated/test_federated_coll.cc index ad053f286cd4..6b7000ef926b 100644 --- a/tests/cpp/plugin/federated/test_federated_coll.cc +++ b/tests/cpp/plugin/federated/test_federated_coll.cc @@ -60,8 +60,7 @@ TEST_F(FederatedCollTest, Allgather) { std::vector buffer(n_workers, 0); buffer[comm->Rank()] = comm->Rank(); - auto rc = coll.Allgather(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), - sizeof(int)); + auto rc = coll.Allgather(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()})); ASSERT_TRUE(rc.OK()); for (auto i = 0; i < n_workers; i++) { ASSERT_EQ(buffer[i], i); diff --git a/tests/cpp/plugin/federated/test_federated_coll.cu b/tests/cpp/plugin/federated/test_federated_coll.cu index a6ec7e352192..008952a4f629 100644 --- a/tests/cpp/plugin/federated/test_federated_coll.cu +++ b/tests/cpp/plugin/federated/test_federated_coll.cu @@ -5,13 +5,13 @@ #include #include // for Result +#include "../../../../src/collective/allreduce.h" #include "../../../../src/common/common.h" // for AllVisibleGPUs #include "../../../../src/common/device_helpers.cuh" // for device_vector #include "../../../../src/common/type.h" // for EraseType #include "../../collective/test_worker.h" // for SocketTest #include "../../helpers.h" // for MakeCUDACtx #include "federated_coll.cuh" -#include "federated_comm.cuh" #include "test_worker.h" // for TestFederated namespace xgboost::collective { @@ -71,7 +71,7 @@ void TestAllgather(std::shared_ptr comm, std::int32_t rank, std:: dh::device_vector buffer(n_workers, 0); buffer[comm->Rank()] = comm->Rank(); - auto rc = w.coll->Allgather(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), sizeof(int)); + auto rc = w.coll->Allgather(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer))); ASSERT_TRUE(rc.OK()); for (auto i = 0; i < n_workers; i++) { ASSERT_EQ(buffer[i], i); @@ -108,6 +108,32 @@ TEST_F(FederatedCollTestGPU, Allreduce) { }); } +TEST(FederatedCollGPUGlobal, Allreduce) { + std::int32_t n_workers = common::AllVisibleGPUs(); + TestFederatedGlobal(n_workers, [&] { + auto r = collective::GetRank(); + auto world = collective::GetWorldSize(); + CHECK_EQ(n_workers, world); + + dh::device_vector values(3, r); + auto ctx = MakeCUDACtx(r); + auto rc = collective::Allreduce( + &ctx, linalg::MakeVec(values.data().get(), values.size(), DeviceOrd::CUDA(r)), + Op::kBitwiseOR); + SafeColl(rc); + + std::vector expected(values.size(), 0); + for (std::int32_t rank = 0; rank < world; ++rank) { + for (std::size_t i = 0; i < expected.size(); ++i) { + expected[i] |= rank; + } + } + for (std::size_t i = 0; i < expected.size(); ++i) { + CHECK_EQ(expected[i], values[i]); + } + }); +} + TEST_F(FederatedCollTestGPU, Broadcast) { std::int32_t n_workers = common::AllVisibleGPUs(); TestFederated(n_workers, [=](std::shared_ptr comm, std::int32_t rank) { diff --git a/tests/cpp/plugin/federated/test_federated_comm.cc b/tests/cpp/plugin/federated/test_federated_comm.cc index 0d0692b5f59c..16edc685fde3 100644 --- a/tests/cpp/plugin/federated/test_federated_comm.cc +++ b/tests/cpp/plugin/federated/test_federated_comm.cc @@ -1,5 +1,5 @@ /** - * Copyright 2022-2023, XGBoost contributors + * Copyright 2022-2024, XGBoost contributors */ #include #include @@ -9,7 +9,7 @@ #include "../../../../plugin/federated/federated_comm.h" #include "../../collective/test_worker.h" // for SocketTest -#include "../../helpers.h" // for ExpectThrow +#include "../../helpers.h" // for GMockThrow #include "test_worker.h" // for TestFederated #include "xgboost/json.h" // for Json @@ -20,19 +20,19 @@ class FederatedCommTest : public SocketTest {}; TEST_F(FederatedCommTest, ThrowOnWorldSizeTooSmall) { auto construct = [] { FederatedComm comm{"localhost", 0, 0, 0}; }; - ASSERT_THAT(construct, - ::testing::ThrowsMessage(::testing::HasSubstr("Invalid world size"))); + ASSERT_THAT(construct, GMockThrow("Invalid world size")); } TEST_F(FederatedCommTest, ThrowOnRankTooSmall) { auto construct = [] { FederatedComm comm{"localhost", 0, 1, -1}; }; - ASSERT_THAT(construct, - ::testing::ThrowsMessage(::testing::HasSubstr("Invalid worker rank."))); + ASSERT_THAT(construct, GMockThrow("Invalid worker rank.")); } TEST_F(FederatedCommTest, ThrowOnRankTooBig) { - auto construct = [] { FederatedComm comm{"localhost", 0, 1, 1}; }; - ExpectThrow("Invalid worker rank.", construct); + auto construct = [] { + FederatedComm comm{"localhost", 0, 1, 1}; + }; + ASSERT_THAT(construct, GMockThrow("Invalid worker rank.")); } TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) { @@ -43,7 +43,7 @@ TEST_F(FederatedCommTest, ThrowOnWorldSizeNotInteger) { config["federated_rank"] = Integer(0); FederatedComm comm{DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config}; }; - ExpectThrow("got: `String`", construct); + ASSERT_THAT(construct, GMockThrow("got: `String`")); } TEST_F(FederatedCommTest, ThrowOnRankNotInteger) { @@ -54,7 +54,7 @@ TEST_F(FederatedCommTest, ThrowOnRankNotInteger) { config["federated_rank"] = std::string("0"); FederatedComm comm(DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, "", config); }; - ExpectThrow("got: `String`", construct); + ASSERT_THAT(construct, GMockThrow("got: `String`")); } TEST_F(FederatedCommTest, GetWorldSizeAndRank) { diff --git a/tests/cpp/plugin/federated/test_federated_tracker.cc b/tests/cpp/plugin/federated/test_federated_tracker.cc index 81ff95540b15..aa979ff15348 100644 --- a/tests/cpp/plugin/federated/test_federated_tracker.cc +++ b/tests/cpp/plugin/federated/test_federated_tracker.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023, XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include @@ -8,7 +8,6 @@ #include "../../../../src/collective/tracker.h" // for GetHostAddress #include "federated_tracker.h" -#include "test_worker.h" #include "xgboost/json.h" // for Json namespace xgboost::collective { @@ -26,7 +25,7 @@ TEST(FederatedTrackerTest, Basic) { ASSERT_GE(tracker->Port(), 1); std::string host; auto rc = GetHostAddress(&host); - ASSERT_EQ(get(args["DMLC_TRACKER_URI"]), host); + ASSERT_EQ(get(args["dmlc_tracker_uri"]), host); rc = tracker->Shutdown(); ASSERT_TRUE(rc.OK()); diff --git a/tests/cpp/plugin/federated/test_worker.h b/tests/cpp/plugin/federated/test_worker.h index d0edecc15b7e..8ec76237df5c 100644 --- a/tests/cpp/plugin/federated/test_worker.h +++ b/tests/cpp/plugin/federated/test_worker.h @@ -11,12 +11,24 @@ #include "../../../../plugin/federated/federated_tracker.h" #include "../../../../src/collective/comm_group.h" +#include "../../../../src/collective/communicator-inl.h" #include "federated_comm.h" // for FederatedComm #include "xgboost/json.h" // for Json namespace xgboost::collective { +inline Json FederatedTestConfig(std::int32_t n_workers, std::int32_t port, std::int32_t i) { + Json config{Object{}}; + config["dmlc_communicator"] = std::string{"federated"}; + config["dmlc_task_id"] = std::to_string(i); + config["dmlc_retry"] = 2; + config["federated_world_size"] = n_workers; + config["federated_rank"] = i; + config["federated_server_address"] = "0.0.0.0:" + std::to_string(port); + return config; +} + template -void TestFederated(std::int32_t n_workers, WorkerFn&& fn) { +void TestFederatedImpl(std::int32_t n_workers, WorkerFn&& fn) { Json config{Object()}; config["federated_secure"] = Boolean{false}; config["n_workers"] = Integer{n_workers}; @@ -30,16 +42,7 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) { std::int32_t port = tracker.Port(); for (std::int32_t i = 0; i < n_workers; ++i) { - workers.emplace_back([=] { - Json config{Object{}}; - config["federated_world_size"] = n_workers; - config["federated_rank"] = i; - config["federated_server_address"] = "0.0.0.0:" + std::to_string(port); - auto comm = std::make_shared( - DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, std::to_string(i), config); - - fn(comm, i); - }); + workers.emplace_back([=] { fn(port, i); }); } for (auto& t : workers) { @@ -52,38 +55,32 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) { } template -void TestFederatedGroup(std::int32_t n_workers, WorkerFn&& fn) { - Json config{Object()}; - config["federated_secure"] = Boolean{false}; - config["n_workers"] = Integer{n_workers}; - FederatedTracker tracker{config}; - auto fut = tracker.Run(); - - std::vector workers; - auto rc = tracker.WaitUntilReady(); - ASSERT_TRUE(rc.OK()) << rc.Report(); - std::int32_t port = tracker.Port(); +void TestFederated(std::int32_t n_workers, WorkerFn&& fn) { + TestFederatedImpl(n_workers, [&](std::int32_t port, std::int32_t i) { + auto config = FederatedTestConfig(n_workers, port, i); + auto comm = std::make_shared( + DefaultRetry(), std::chrono::seconds{DefaultTimeoutSec()}, std::to_string(i), config); - for (std::int32_t i = 0; i < n_workers; ++i) { - workers.emplace_back([=] { - Json config{Object{}}; - config["dmlc_communicator"] = std::string{"federated"}; - config["dmlc_task_id"] = std::to_string(i); - config["dmlc_retry"] = 2; - config["federated_world_size"] = n_workers; - config["federated_rank"] = i; - config["federated_server_address"] = "0.0.0.0:" + std::to_string(port); - std::shared_ptr comm_group{CommGroup::Create(config)}; - fn(comm_group, i); - }); - } + fn(comm, i); + }); +} - for (auto& t : workers) { - t.join(); - } +template +void TestFederatedGroup(std::int32_t n_workers, WorkerFn&& fn) { + TestFederatedImpl(n_workers, [&](std::int32_t port, std::int32_t i) { + auto config = FederatedTestConfig(n_workers, port, i); + std::shared_ptr comm_group{CommGroup::Create(config)}; + fn(comm_group, i); + }); +} - rc = tracker.Shutdown(); - ASSERT_TRUE(rc.OK()) << rc.Report(); - ASSERT_TRUE(fut.get().OK()); +template +void TestFederatedGlobal(std::int32_t n_workers, WorkerFn&& fn) { + TestFederatedImpl(n_workers, [&](std::int32_t port, std::int32_t i) { + auto config = FederatedTestConfig(n_workers, port, i); + collective::Init(config); + fn(); + collective::Finalize(); + }); } } // namespace xgboost::collective diff --git a/tests/cpp/plugin/helpers.h b/tests/cpp/plugin/helpers.h deleted file mode 100644 index 85f2e014bdf2..000000000000 --- a/tests/cpp/plugin/helpers.h +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright 2022-2023, XGBoost contributors - */ -#pragma once - -#include -#include -#include -#include - -#include -#include // for thread, sleep_for - -#include "../../../plugin/federated/federated_server.h" -#include "../../../src/collective/communicator-inl.h" -#include "../../../src/common/threading_utils.h" - -namespace xgboost { - -class ServerForTest { - std::string server_address_; - std::unique_ptr server_thread_; - std::unique_ptr server_; - - public: - explicit ServerForTest(std::size_t world_size) { - server_thread_.reset(new std::thread([this, world_size] { - grpc::ServerBuilder builder; - xgboost::federated::FederatedService service{static_cast(world_size)}; - int selected_port; - builder.AddListeningPort("localhost:0", grpc::InsecureServerCredentials(), &selected_port); - builder.RegisterService(&service); - server_ = builder.BuildAndStart(); - server_address_ = std::string("localhost:") + std::to_string(selected_port); - server_->Wait(); - })); - } - - ~ServerForTest() { - using namespace std::chrono_literals; - while (!server_) { - std::this_thread::sleep_for(100ms); - } - server_->Shutdown(); - while (!server_thread_) { - std::this_thread::sleep_for(100ms); - } - server_thread_->join(); - } - - auto Address() const { - using namespace std::chrono_literals; - while (server_address_.empty()) { - std::this_thread::sleep_for(100ms); - } - return server_address_; - } -}; - -class BaseFederatedTest : public ::testing::Test { - protected: - void SetUp() override { server_ = std::make_unique(kWorldSize); } - - void TearDown() override { server_.reset(nullptr); } - - static int constexpr kWorldSize{2}; - std::unique_ptr server_; -}; - -template -void RunWithFederatedCommunicator(int32_t world_size, std::string const& server_address, - Function&& function, Args&&... args) { - auto run = [&](auto rank) { - Json config{JsonObject()}; - config["xgboost_communicator"] = String("federated"); - config["federated_secure"] = false; - config["federated_server_address"] = String(server_address); - config["federated_world_size"] = world_size; - config["federated_rank"] = rank; - xgboost::collective::Init(config); - - std::forward(function)(std::forward(args)...); - - xgboost::collective::Finalize(); - }; -#if defined(_OPENMP) - common::ParallelFor(world_size, world_size, run); -#else - std::vector threads; - for (auto rank = 0; rank < world_size; rank++) { - threads.emplace_back(run, rank); - } - for (auto& thread : threads) { - thread.join(); - } -#endif -} - -} // namespace xgboost diff --git a/tests/cpp/plugin/sycl_helpers.h b/tests/cpp/plugin/sycl_helpers.h index c5cdd3ea5b08..afc403d86333 100644 --- a/tests/cpp/plugin/sycl_helpers.h +++ b/tests/cpp/plugin/sycl_helpers.h @@ -8,22 +8,23 @@ namespace xgboost::sycl { template void VerifySyclVector(const USMVector& sycl_vector, - const Container& host_vector) { + const Container& host_vector, T eps = T()) { ASSERT_EQ(sycl_vector.Size(), host_vector.size()); size_t size = sycl_vector.Size(); for (size_t i = 0; i < size; ++i) { - ASSERT_EQ(sycl_vector[i], host_vector[i]); + EXPECT_NEAR(sycl_vector[i], host_vector[i], eps); } } template -void VerifySyclVector(const std::vector& sycl_vector, const Container& host_vector) { +void VerifySyclVector(const std::vector& sycl_vector, + const Container& host_vector, T eps = T()) { ASSERT_EQ(sycl_vector.size(), host_vector.size()); size_t size = sycl_vector.size(); for (size_t i = 0; i < size; ++i) { - ASSERT_EQ(sycl_vector[i], host_vector[i]); + EXPECT_NEAR(sycl_vector[i], host_vector[i], eps); } } diff --git a/tests/cpp/plugin/test_federated_adapter.cu b/tests/cpp/plugin/test_federated_adapter.cu deleted file mode 100644 index cec180e703e0..000000000000 --- a/tests/cpp/plugin/test_federated_adapter.cu +++ /dev/null @@ -1,98 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#include -#include - -#include -#include -#include - -#include "../../../plugin/federated/federated_communicator.h" -#include "../../../src/collective/communicator-inl.cuh" -#include "../../../src/collective/device_communicator_adapter.cuh" -#include "../helpers.h" -#include "./helpers.h" - -namespace xgboost::collective { - -class FederatedAdapterTest : public BaseFederatedTest {}; - -TEST(FederatedAdapterSimpleTest, ThrowOnInvalidDeviceOrdinal) { - auto construct = []() { DeviceCommunicatorAdapter adapter{-1}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -namespace { -void VerifyAllReduceSum() { - auto const world_size = collective::GetWorldSize(); - auto const rank = collective::GetRank(); - auto const device = GPUIDX; - int count = 3; - common::SetDevice(device); - thrust::device_vector buffer(count, 0); - thrust::sequence(buffer.begin(), buffer.end()); - collective::AllReduce(device, buffer.data().get(), count); - thrust::host_vector host_buffer = buffer; - EXPECT_EQ(host_buffer.size(), count); - for (auto i = 0; i < count; i++) { - EXPECT_EQ(host_buffer[i], i * world_size); - } -} -} // anonymous namespace - -TEST_F(FederatedAdapterTest, MGPUAllReduceSum) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllReduceSum); -} - -namespace { -void VerifyAllGather() { - auto const world_size = collective::GetWorldSize(); - auto const rank = collective::GetRank(); - auto const device = GPUIDX; - common::SetDevice(device); - thrust::device_vector send_buffer(1, rank); - thrust::device_vector receive_buffer(world_size, 0); - collective::AllGather(device, send_buffer.data().get(), receive_buffer.data().get(), - sizeof(double)); - thrust::host_vector host_buffer = receive_buffer; - EXPECT_EQ(host_buffer.size(), world_size); - for (auto i = 0; i < world_size; i++) { - EXPECT_EQ(host_buffer[i], i); - } -} -} // anonymous namespace - -TEST_F(FederatedAdapterTest, MGPUAllGather) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllGather); -} - -namespace { -void VerifyAllGatherV() { - auto const world_size = collective::GetWorldSize(); - auto const rank = collective::GetRank(); - auto const device = GPUIDX; - int const count = rank + 2; - common::SetDevice(device); - thrust::device_vector buffer(count, 0); - thrust::sequence(buffer.begin(), buffer.end()); - std::vector segments(world_size); - dh::caching_device_vector receive_buffer{}; - - collective::AllGatherV(device, buffer.data().get(), count, &segments, &receive_buffer); - - EXPECT_EQ(segments[0], 2); - EXPECT_EQ(segments[1], 3); - thrust::host_vector host_buffer = receive_buffer; - EXPECT_EQ(host_buffer.size(), 5); - int expected[] = {0, 1, 0, 1, 2}; - for (auto i = 0; i < 5; i++) { - EXPECT_EQ(host_buffer[i], expected[i]); - } -} -} // anonymous namespace - -TEST_F(FederatedAdapterTest, MGPUAllGatherV) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllGatherV); -} -} // namespace xgboost::collective diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc deleted file mode 100644 index 68b112f1c7b1..000000000000 --- a/tests/cpp/plugin/test_federated_communicator.cc +++ /dev/null @@ -1,161 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#include -#include - -#include -#include - -#include "../../../plugin/federated/federated_communicator.h" -#include "helpers.h" - -namespace xgboost::collective { - -class FederatedCommunicatorTest : public BaseFederatedTest { - public: - static void VerifyAllgather(int rank, const std::string &server_address) { - FederatedCommunicator comm{kWorldSize, rank, server_address}; - CheckAllgather(comm, rank); - } - - static void VerifyAllgatherV(int rank, const std::string &server_address) { - FederatedCommunicator comm{kWorldSize, rank, server_address}; - CheckAllgatherV(comm, rank); - } - - static void VerifyAllreduce(int rank, const std::string &server_address) { - FederatedCommunicator comm{kWorldSize, rank, server_address}; - CheckAllreduce(comm); - } - - static void VerifyBroadcast(int rank, const std::string &server_address) { - FederatedCommunicator comm{kWorldSize, rank, server_address}; - CheckBroadcast(comm, rank); - } - - protected: - static void CheckAllgather(FederatedCommunicator &comm, int rank) { - std::string input{static_cast('0' + rank)}; - auto output = comm.AllGather(input); - for (auto i = 0; i < kWorldSize; i++) { - EXPECT_EQ(output[i], static_cast('0' + i)); - } - } - - static void CheckAllgatherV(FederatedCommunicator &comm, int rank) { - std::vector inputs{"Federated", " Learning!!!"}; - auto output = comm.AllGatherV(inputs[rank]); - EXPECT_EQ(output, "Federated Learning!!!"); - } - - static void CheckAllreduce(FederatedCommunicator &comm) { - int buffer[] = {1, 2, 3, 4, 5}; - comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum); - int expected[] = {2, 4, 6, 8, 10}; - for (auto i = 0; i < 5; i++) { - EXPECT_EQ(buffer[i], expected[i]); - } - } - - static void CheckBroadcast(FederatedCommunicator &comm, int rank) { - if (rank == 0) { - std::string buffer{"hello"}; - comm.Broadcast(&buffer[0], buffer.size(), 0); - EXPECT_EQ(buffer, "hello"); - } else { - std::string buffer{" "}; - comm.Broadcast(&buffer[0], buffer.size(), 0); - EXPECT_EQ(buffer, "hello"); - } - } -}; - -TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) { - auto construct = [] { FederatedCommunicator comm{0, 0, "localhost:0", "", "", ""}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooSmall) { - auto construct = [] { FederatedCommunicator comm{1, -1, "localhost:0", "", "", ""}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooBig) { - auto construct = [] { FederatedCommunicator comm{1, 1, "localhost:0", "", "", ""}; }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) { - auto construct = [] { - Json config{JsonObject()}; - config["federated_server_address"] = std::string("localhost:0"); - config["federated_world_size"] = std::string("1"); - config["federated_rank"] = Integer(0); - FederatedCommunicator::Create(config); - }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger) { - auto construct = [] { - Json config{JsonObject()}; - config["federated_server_address"] = std::string("localhost:0"); - config["federated_world_size"] = 1; - config["federated_rank"] = std::string("0"); - FederatedCommunicator::Create(config); - }; - EXPECT_THROW(construct(), dmlc::Error); -} - -TEST(FederatedCommunicatorSimpleTest, GetWorldSizeAndRank) { - FederatedCommunicator comm{6, 3, "localhost:0"}; - EXPECT_EQ(comm.GetWorldSize(), 6); - EXPECT_EQ(comm.GetRank(), 3); -} - -TEST(FederatedCommunicatorSimpleTest, IsDistributed) { - FederatedCommunicator comm{2, 1, "localhost:0"}; - EXPECT_TRUE(comm.IsDistributed()); -} - -TEST_F(FederatedCommunicatorTest, Allgather) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgather, rank, server_->Address()); - } - for (auto &thread : threads) { - thread.join(); - } -} - -TEST_F(FederatedCommunicatorTest, AllgatherV) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedCommunicatorTest::VerifyAllgatherV, rank, server_->Address()); - } - for (auto &thread : threads) { - thread.join(); - } -} - -TEST_F(FederatedCommunicatorTest, Allreduce) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedCommunicatorTest::VerifyAllreduce, rank, server_->Address()); - } - for (auto &thread : threads) { - thread.join(); - } -} - -TEST_F(FederatedCommunicatorTest, Broadcast) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedCommunicatorTest::VerifyBroadcast, rank, server_->Address()); - } - for (auto &thread : threads) { - thread.join(); - } -} -} // namespace xgboost::collective diff --git a/tests/cpp/plugin/test_federated_data.cc b/tests/cpp/plugin/test_federated_data.cc index 6a8233a0fefe..d0f649152bd4 100644 --- a/tests/cpp/plugin/test_federated_data.cc +++ b/tests/cpp/plugin/test_federated_data.cc @@ -6,16 +6,13 @@ #include -#include "../../../plugin/federated/federated_server.h" #include "../../../src/collective/communicator-inl.h" #include "../filesystem.h" #include "../helpers.h" -#include "helpers.h" +#include "federated/test_worker.h" namespace xgboost { -class FederatedDataTest : public BaseFederatedTest {}; - void VerifyLoadUri() { auto const rank = collective::GetRank(); @@ -47,7 +44,8 @@ void VerifyLoadUri() { } } -TEST_F(FederatedDataTest, LoadUri) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLoadUri); +TEST(FederatedDataTest, LoadUri) { + static int constexpr kWorldSize{2}; + collective::TestFederatedGlobal(kWorldSize, [] { VerifyLoadUri(); }); } } // namespace xgboost diff --git a/tests/cpp/plugin/test_federated_learner.cc b/tests/cpp/plugin/test_federated_learner.cc index a9adedc63683..948914e0fa6b 100644 --- a/tests/cpp/plugin/test_federated_learner.cc +++ b/tests/cpp/plugin/test_federated_learner.cc @@ -1,17 +1,19 @@ -/*! - * Copyright 2023 XGBoost contributors +/** + * Copyright 2023-2024, XGBoost contributors + * + * Some other tests for federated learning are in the main test suite (test_learner.cc), + * gaurded by the `XGBOOST_USE_FEDERATED`. */ #include #include #include #include -#include "../../../plugin/federated/federated_server.h" #include "../../../src/collective/communicator-inl.h" -#include "../../../src/common/linalg_op.h" +#include "../../../src/common/linalg_op.h" // for begin, end #include "../helpers.h" #include "../objective_helpers.h" // for MakeObjNamesForTest, ObjTestNameGenerator -#include "helpers.h" +#include "federated/test_worker.h" namespace xgboost { namespace { @@ -36,32 +38,16 @@ auto MakeModel(std::string tree_method, std::string device, std::string objectiv return model; } -void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json expected_model, - std::string tree_method, std::string device, std::string objective) { - auto const world_size = collective::GetWorldSize(); - auto const rank = collective::GetRank(); +void VerifyObjective(std::size_t rows, std::size_t cols, float expected_base_score, + Json expected_model, std::string const &tree_method, std::string device, + std::string const &objective) { + auto rank = collective::GetRank(); std::shared_ptr dmat{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)}; if (rank == 0) { - auto &h_upper = dmat->Info().labels_upper_bound_.HostVector(); - auto &h_lower = dmat->Info().labels_lower_bound_.HostVector(); - h_lower.resize(rows); - h_upper.resize(rows); - for (size_t i = 0; i < rows; ++i) { - h_lower[i] = 1; - h_upper[i] = 10; - } - - if (objective.find("rank:") != std::string::npos) { - auto h_label = dmat->Info().labels.HostView(); - std::size_t k = 0; - for (auto &v : h_label) { - v = k % 2 == 0; - ++k; - } - } + MakeLabelForObjTest(dmat, objective); } - std::shared_ptr sliced{dmat->SliceCol(world_size, rank)}; + std::shared_ptr sliced{dmat->SliceCol(collective::GetWorldSize(), rank)}; auto model = MakeModel(tree_method, device, objective, sliced); auto base_score = GetBaseScore(model); @@ -71,18 +57,15 @@ void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json e } // namespace class VerticalFederatedLearnerTest : public ::testing::TestWithParam { - std::unique_ptr server_; static int constexpr kWorldSize{3}; protected: - void SetUp() override { server_ = std::make_unique(kWorldSize); } - void TearDown() override { server_.reset(nullptr); } - void Run(std::string tree_method, std::string device, std::string objective) { static auto constexpr kRows{16}; static auto constexpr kCols{16}; std::shared_ptr dmat{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)}; + MakeLabelForObjTest(dmat, objective); auto &h_upper = dmat->Info().labels_upper_bound_.HostVector(); auto &h_lower = dmat->Info().labels_lower_bound_.HostVector(); @@ -103,9 +86,9 @@ class VerticalFederatedLearnerTest : public ::testing::TestWithParamAddress(), &VerifyObjective, kRows, kCols, - score, model, tree_method, device, objective); + collective::TestFederatedGlobal(kWorldSize, [&]() { + VerifyObjective(kRows, kCols, score, model, tree_method, device, objective); + }); } }; diff --git a/tests/cpp/plugin/test_federated_metrics.cc b/tests/cpp/plugin/test_federated_metrics.cc deleted file mode 100644 index 1bdda567f841..000000000000 --- a/tests/cpp/plugin/test_federated_metrics.cc +++ /dev/null @@ -1,243 +0,0 @@ -/*! - * Copyright 2023 XGBoost contributors - */ -#include - -#include "../metric/test_auc.h" -#include "../metric/test_elementwise_metric.h" -#include "../metric/test_multiclass_metric.h" -#include "../metric/test_rank_metric.h" -#include "../metric/test_survival_metric.h" -#include "helpers.h" - -namespace { -class FederatedMetricTest : public xgboost::BaseFederatedTest {}; -} // anonymous namespace - -namespace xgboost { -namespace metric { -TEST_F(FederatedMetricTest, BinaryAUCRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyBinaryAUC, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, BinaryAUCColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyBinaryAUC, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, MultiClassAUCRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassAUC, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, MultiClassAUCColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassAUC, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, RankingAUCRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingAUC, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, RankingAUCColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingAUC, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, PRAUCRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPRAUC, DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, PRAUCColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPRAUC, DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, MultiClassPRAUCRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassPRAUC, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, MultiClassPRAUCColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassPRAUC, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, RankingPRAUCRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingPRAUC, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, RankingPRAUCColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRankingPRAUC, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, RMSERowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSE, DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, RMSEColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSE, DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, RMSLERowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSLE, DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, RMSLEColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyRMSLE, DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, MAERowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAE, DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, MAEColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAE, DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, MAPERowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAPE, DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, MAPEColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAPE, DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, MPHERowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMPHE, DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, MPHEColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMPHE, DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, LogLossRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLogLoss, DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, LogLossColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyLogLoss, DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, ErrorRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyError, DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, ErrorColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyError, DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, PoissonNegLogLikRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPoissonNegLogLik, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, PoissonNegLogLikColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPoissonNegLogLik, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, MultiRMSERowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiRMSE, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, MultiRMSEColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiRMSE, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, QuantileRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyQuantile, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, QuantileColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyQuantile, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, MultiClassErrorRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassError, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, MultiClassErrorColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassError, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, MultiClassLogLossRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassLogLoss, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, MultiClassLogLossColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMultiClassLogLoss, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, PrecisionRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPrecision, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, PrecisionColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyPrecision, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, NDCGRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCG, DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, NDCGColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCG, DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, MAPRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAP, DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, MAPColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyMAP, DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, NDCGExpGainRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCGExpGain, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, NDCGExpGainColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyNDCGExpGain, - DataSplitMode::kCol); -} -} // namespace metric -} // namespace xgboost - -namespace xgboost { -namespace common { -TEST_F(FederatedMetricTest, AFTNegLogLikRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAFTNegLogLik, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, AFTNegLogLikColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAFTNegLogLik, - DataSplitMode::kCol); -} - -TEST_F(FederatedMetricTest, IntervalRegressionAccuracyRowSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyIntervalRegressionAccuracy, - DataSplitMode::kRow); -} - -TEST_F(FederatedMetricTest, IntervalRegressionAccuracyColumnSplit) { - RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyIntervalRegressionAccuracy, - DataSplitMode::kCol); -} -} // namespace common -} // namespace xgboost diff --git a/tests/cpp/plugin/test_federated_server.cc b/tests/cpp/plugin/test_federated_server.cc deleted file mode 100644 index c40e58fa388f..000000000000 --- a/tests/cpp/plugin/test_federated_server.cc +++ /dev/null @@ -1,133 +0,0 @@ -/*! - * Copyright 2017-2020 XGBoost contributors - */ -#include - -#include -#include - -#include "federated_client.h" -#include "helpers.h" - -namespace xgboost { - -class FederatedServerTest : public BaseFederatedTest { - public: - static void VerifyAllgather(int rank, const std::string& server_address) { - federated::FederatedClient client{server_address, rank}; - CheckAllgather(client, rank); - } - - static void VerifyAllgatherV(int rank, const std::string& server_address) { - federated::FederatedClient client{server_address, rank}; - CheckAllgatherV(client, rank); - } - - static void VerifyAllreduce(int rank, const std::string& server_address) { - federated::FederatedClient client{server_address, rank}; - CheckAllreduce(client); - } - - static void VerifyBroadcast(int rank, const std::string& server_address) { - federated::FederatedClient client{server_address, rank}; - CheckBroadcast(client, rank); - } - - static void VerifyMixture(int rank, const std::string& server_address) { - federated::FederatedClient client{server_address, rank}; - for (auto i = 0; i < 10; i++) { - CheckAllgather(client, rank); - CheckAllreduce(client); - CheckBroadcast(client, rank); - } - } - - protected: - static void CheckAllgather(federated::FederatedClient& client, int rank) { - int data[] = {rank}; - std::string send_buffer(reinterpret_cast(data), sizeof(data)); - auto reply = client.Allgather(send_buffer); - auto const* result = reinterpret_cast(reply.data()); - for (auto i = 0; i < kWorldSize; i++) { - EXPECT_EQ(result[i], i); - } - } - - static void CheckAllgatherV(federated::FederatedClient& client, int rank) { - std::vector inputs{"Hello,", " World!"}; - auto reply = client.AllgatherV(inputs[rank]); - EXPECT_EQ(reply, "Hello, World!"); - } - - static void CheckAllreduce(federated::FederatedClient& client) { - int data[] = {1, 2, 3, 4, 5}; - std::string send_buffer(reinterpret_cast(data), sizeof(data)); - auto reply = client.Allreduce(send_buffer, federated::INT32, federated::SUM); - auto const* result = reinterpret_cast(reply.data()); - int expected[] = {2, 4, 6, 8, 10}; - for (auto i = 0; i < 5; i++) { - EXPECT_EQ(result[i], expected[i]); - } - } - - static void CheckBroadcast(federated::FederatedClient& client, int rank) { - std::string send_buffer{}; - if (rank == 0) { - send_buffer = "hello broadcast"; - } - auto reply = client.Broadcast(send_buffer, 0); - EXPECT_EQ(reply, "hello broadcast") << "rank " << rank; - } -}; - -TEST_F(FederatedServerTest, Allgather) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedServerTest::VerifyAllgather, rank, server_->Address()); - } - for (auto& thread : threads) { - thread.join(); - } -} - -TEST_F(FederatedServerTest, AllgatherV) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedServerTest::VerifyAllgatherV, rank, server_->Address()); - } - for (auto& thread : threads) { - thread.join(); - } -} - -TEST_F(FederatedServerTest, Allreduce) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedServerTest::VerifyAllreduce, rank, server_->Address()); - } - for (auto& thread : threads) { - thread.join(); - } -} - -TEST_F(FederatedServerTest, Broadcast) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedServerTest::VerifyBroadcast, rank, server_->Address()); - } - for (auto& thread : threads) { - thread.join(); - } -} - -TEST_F(FederatedServerTest, Mixture) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(&FederatedServerTest::VerifyMixture, rank, server_->Address()); - } - for (auto& thread : threads) { - thread.join(); - } -} - -} // namespace xgboost diff --git a/tests/cpp/plugin/test_sycl_ghist_builder.cc b/tests/cpp/plugin/test_sycl_ghist_builder.cc new file mode 100644 index 000000000000..dacbc75fc3d5 --- /dev/null +++ b/tests/cpp/plugin/test_sycl_ghist_builder.cc @@ -0,0 +1,157 @@ +/** + * Copyright 2020-2024 by XGBoost contributors + */ +#include + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#pragma GCC diagnostic ignored "-W#pragma-messages" +#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix +#pragma GCC diagnostic pop + +#include "../../../plugin/sycl/common/hist_util.h" +#include "../../../plugin/sycl/device_manager.h" +#include "sycl_helpers.h" +#include "../helpers.h" + +namespace xgboost::sycl::common { + +template +void GHistBuilderTest(float sparsity, bool force_atomic_use) { + const size_t num_rows = 8; + const size_t num_columns = 1; + const int n_bins = 2; + const GradientSumT eps = 1e-6; + + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + DeviceManager device_manager; + auto qu = device_manager.GetQueue(ctx.Device()); + + auto p_fmat = RandomDataGenerator{num_rows, num_columns, sparsity}.GenerateDMatrix(); + sycl::DeviceMatrix dmat; + dmat.Init(qu, p_fmat.get()); + + GHistIndexMatrix gmat_sycl; + gmat_sycl.Init(qu, &ctx, dmat, n_bins); + + xgboost::GHistIndexMatrix gmat{&ctx, p_fmat.get(), n_bins, 0.3, false}; + + RowSetCollection row_set_collection; + auto& row_indices = row_set_collection.Data(); + row_indices.Resize(&qu, num_rows); + size_t* p_row_indices = row_indices.Data(); + + qu.submit([&](::sycl::handler& cgh) { + cgh.parallel_for<>(::sycl::range<1>(num_rows), + [p_row_indices](::sycl::item<1> pid) { + const size_t idx = pid.get_id(0); + p_row_indices[idx] = idx; + }); + }).wait_and_throw(); + row_set_collection.Init(); + + auto builder = GHistBuilder(qu, n_bins); + + std::vector gpair = { + {0.1f, 0.2f}, {0.3f, 0.4f}, {0.5f, 0.6f}, {0.7f, 0.8f}, + {0.9f, 0.1f}, {0.2f, 0.3f}, {0.4f, 0.5f}, {0.6f, 0.7f}}; + CHECK_EQ(gpair.size(), num_rows); + USMVector gpair_device(&qu, gpair); + + std::vector hist_host(2*n_bins); + GHistRow hist(&qu, 2 * n_bins); + ::sycl::event event; + + const size_t nblocks = 2; + GHistRow hist_buffer(&qu, 2 * nblocks * n_bins); + + InitHist(qu, &hist, hist.Size(), &event); + InitHist(qu, &hist_buffer, hist_buffer.Size(), &event); + + event = builder.BuildHist(gpair_device, row_set_collection[0], gmat_sycl, &hist, + sparsity < eps , &hist_buffer, event, force_atomic_use); + qu.memcpy(hist_host.data(), hist.Data(), + 2 * n_bins * sizeof(GradientSumT), event); + qu.wait_and_throw(); + + // Build hist on host to compare + std::vector hist_desired(2*n_bins); + for (size_t rid = 0; rid < num_rows; ++rid) { + const size_t ibegin = gmat.row_ptr[rid]; + const size_t iend = gmat.row_ptr[rid + 1]; + for (size_t i = ibegin; i < iend; ++i) { + const size_t bin_idx = gmat.index[i]; + hist_desired[2*bin_idx] += gpair[rid].GetGrad(); + hist_desired[2*bin_idx+1] += gpair[rid].GetHess(); + } + } + + VerifySyclVector(hist_host, hist_desired, eps); +} + +template +void GHistSubtractionTest() { + const size_t n_bins = 4; + using GHistType = GHistRow; + + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + DeviceManager device_manager; + auto qu = device_manager.GetQueue(ctx.Device()); + + ::sycl::event event; + std::vector hist1_host = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}; + GHistType hist1(&qu, 2 * n_bins); + event = qu.memcpy(hist1.Data(), hist1_host.data(), + 2 * n_bins * sizeof(GradientSumT), event); + + std::vector hist2_host = {0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1}; + GHistType hist2(&qu, 2 * n_bins); + event = qu.memcpy(hist2.Data(), hist2_host.data(), + 2 * n_bins * sizeof(GradientSumT), event); + + std::vector hist3_host(2 * n_bins); + GHistType hist3(&qu, 2 * n_bins); + event = SubtractionHist(qu, &hist3, hist1, hist2, n_bins, event); + qu.memcpy(hist3_host.data(), hist3.Data(), + 2 * n_bins * sizeof(GradientSumT), event); + qu.wait_and_throw(); + + std::vector hist3_desired(2 * n_bins); + for (size_t idx = 0; idx < 2 * n_bins; ++idx) { + hist3_desired[idx] = hist1_host[idx] - hist2_host[idx]; + } + + const GradientSumT eps = 1e-6; + VerifySyclVector(hist3_host, hist3_desired, eps); +} + +TEST(SyclGHistBuilder, ByBlockDenseCase) { + GHistBuilderTest(0.0, false); + GHistBuilderTest(0.0, false); +} + +TEST(SyclGHistBuilder, ByBlockSparseCase) { + GHistBuilderTest(0.3, false); + GHistBuilderTest(0.3, false); +} + +TEST(SyclGHistBuilder, ByAtomicDenseCase) { + GHistBuilderTest(0.0, true); + GHistBuilderTest(0.0, true); +} + +TEST(SyclGHistBuilder, ByAtomicSparseCase) { + GHistBuilderTest(0.3, true); + GHistBuilderTest(0.3, true); +} + +TEST(SyclGHistBuilder, Subtraction) { + GHistSubtractionTest(); + GHistSubtractionTest(); +} + +} // namespace xgboost::sycl::common diff --git a/tests/cpp/plugin/test_sycl_gradient_index.cc b/tests/cpp/plugin/test_sycl_gradient_index.cc index 35fc7fbbe345..4d605ce7aabe 100644 --- a/tests/cpp/plugin/test_sycl_gradient_index.cc +++ b/tests/cpp/plugin/test_sycl_gradient_index.cc @@ -49,7 +49,8 @@ TEST(SyclGradientIndex, Init) { auto p_fmat = RandomDataGenerator{n_rows, n_columns, 0.3}.GenerateDMatrix(); - sycl::DeviceMatrix dmat(qu, p_fmat.get()); + sycl::DeviceMatrix dmat; + dmat.Init(qu, p_fmat.get()); int max_bins = 256; common::GHistIndexMatrix gmat_sycl; diff --git a/tests/cpp/plugin/test_sycl_hist_updater.cc b/tests/cpp/plugin/test_sycl_hist_updater.cc new file mode 100644 index 000000000000..1ef771a0c7ec --- /dev/null +++ b/tests/cpp/plugin/test_sycl_hist_updater.cc @@ -0,0 +1,349 @@ +/** + * Copyright 2020-2024 by XGBoost contributors + */ +#include + +#include + +#include "../../../plugin/sycl/tree/hist_updater.h" +#include "../../../plugin/sycl/device_manager.h" + +#include "../helpers.h" + +namespace xgboost::sycl::tree { + +// Use this class to test the protected methods of HistUpdater +template +class TestHistUpdater : public HistUpdater { + public: + TestHistUpdater(const Context* ctx, + ::sycl::queue qu, + const xgboost::tree::TrainParam& param, + std::unique_ptr pruner, + FeatureInteractionConstraintHost int_constraints_, + DMatrix const* fmat) : HistUpdater(ctx, qu, param, + std::move(pruner), + int_constraints_, fmat) {} + + void TestInitSampling(const USMVector &gpair, + USMVector* row_indices) { + HistUpdater::InitSampling(gpair, row_indices); + } + + auto* TestInitData(const common::GHistIndexMatrix& gmat, + const USMVector &gpair, + const DMatrix& fmat, + const RegTree& tree) { + HistUpdater::InitData(gmat, gpair, fmat, tree); + return &(HistUpdater::row_set_collection_); + } + + const auto* TestBuildHistogramsLossGuide(ExpandEntry entry, + const common::GHistIndexMatrix &gmat, + RegTree *p_tree, + const USMVector &gpair) { + HistUpdater::BuildHistogramsLossGuide(entry, gmat, p_tree, gpair); + return &(HistUpdater::hist_); + } + + auto TestInitNewNode(int nid, + const common::GHistIndexMatrix& gmat, + const USMVector &gpair, + const DMatrix& fmat, + const RegTree& tree) { + HistUpdater::InitNewNode(nid, gmat, gpair, fmat, tree); + return HistUpdater::snode_host_[nid]; + } +}; + +void GenerateRandomGPairs(::sycl::queue* qu, GradientPair* gpair_ptr, size_t num_rows, bool has_neg_hess) { + qu->submit([&](::sycl::handler& cgh) { + cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)), + [=](::sycl::item<1> pid) { + uint64_t i = pid.get_linear_id(); + + constexpr uint32_t seed = 777; + oneapi::dpl::minstd_rand engine(seed, i); + GradientPair::ValueT smallest_hess_val = has_neg_hess ? -1. : 0.; + oneapi::dpl::uniform_real_distribution distr(smallest_hess_val, 1.); + gpair_ptr[i] = {distr(engine), distr(engine)}; + }); + }); + qu->wait(); +} + +template +void TestHistUpdaterSampling(const xgboost::tree::TrainParam& param) { + const size_t num_rows = 1u << 12; + const size_t num_columns = 1; + + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + DeviceManager device_manager; + auto qu = device_manager.GetQueue(ctx.Device()); + ObjInfo task{ObjInfo::kRegression}; + + auto p_fmat = RandomDataGenerator{num_rows, num_columns, 0.0}.GenerateDMatrix(); + + FeatureInteractionConstraintHost int_constraints; + std::unique_ptr pruner{TreeUpdater::Create("prune", &ctx, &task)}; + + TestHistUpdater updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get()); + + USMVector row_indices_0(&qu, num_rows); + USMVector row_indices_1(&qu, num_rows); + USMVector gpair(&qu, num_rows); + GenerateRandomGPairs(&qu, gpair.Data(), num_rows, true); + + updater.TestInitSampling(gpair, &row_indices_0); + + size_t n_samples = row_indices_0.Size(); + // Half of gpairs have neg hess + ASSERT_LT(n_samples, num_rows * 0.5 * param.subsample * 1.2); + ASSERT_GT(n_samples, num_rows * 0.5 * param.subsample / 1.2); + + // Check if two lanunches generate different realisations: + updater.TestInitSampling(gpair, &row_indices_1); + if (row_indices_1.Size() == n_samples) { + std::vector row_indices_0_host(n_samples); + std::vector row_indices_1_host(n_samples); + qu.memcpy(row_indices_0_host.data(), row_indices_0.Data(), n_samples * sizeof(size_t)).wait(); + qu.memcpy(row_indices_1_host.data(), row_indices_1.Data(), n_samples * sizeof(size_t)).wait(); + + // The order in row_indices_0 and row_indices_1 can be different + std::set rows; + for (auto row : row_indices_0_host) { + rows.insert(row); + } + + size_t num_diffs = 0; + for (auto row : row_indices_1_host) { + if (rows.count(row) == 0) num_diffs++; + } + + ASSERT_NE(num_diffs, 0); + } + +} + +template +void TestHistUpdaterInitData(const xgboost::tree::TrainParam& param, bool has_neg_hess) { + const size_t num_rows = 1u << 8; + const size_t num_columns = 1; + const size_t n_bins = 32; + + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + DeviceManager device_manager; + auto qu = device_manager.GetQueue(ctx.Device()); + ObjInfo task{ObjInfo::kRegression}; + + auto p_fmat = RandomDataGenerator{num_rows, num_columns, 0.0}.GenerateDMatrix(); + + FeatureInteractionConstraintHost int_constraints; + std::unique_ptr pruner{TreeUpdater::Create("prune", &ctx, &task)}; + + TestHistUpdater updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get()); + + USMVector gpair(&qu, num_rows); + GenerateRandomGPairs(&qu, gpair.Data(), num_rows, has_neg_hess); + + DeviceMatrix dmat; + dmat.Init(qu, p_fmat.get()); + common::GHistIndexMatrix gmat; + gmat.Init(qu, &ctx, dmat, n_bins); + RegTree tree; + + auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree); + auto& row_indices = row_set_collection->Data(); + + std::vector row_indices_host(row_indices.Size()); + qu.memcpy(row_indices_host.data(), row_indices.DataConst(), row_indices.Size()*sizeof(size_t)).wait(); + + if (!has_neg_hess) { + for (size_t i = 0; i < num_rows; ++i) { + ASSERT_EQ(row_indices_host[i], i); + } + } else { + std::vector gpair_host(num_rows); + qu.memcpy(gpair_host.data(), gpair.Data(), num_rows*sizeof(GradientPair)).wait(); + + std::set rows; + for (size_t i = 0; i < num_rows; ++i) { + if (gpair_host[i].GetHess() >= 0.0f) { + rows.insert(i); + } + } + ASSERT_EQ(rows.size(), row_indices_host.size()); + for (size_t row_idx : row_indices_host) { + ASSERT_EQ(rows.count(row_idx), 1); + } + } +} + +template +void TestHistUpdaterBuildHistogramsLossGuide(const xgboost::tree::TrainParam& param, float sparsity) { + const size_t num_rows = 1u << 8; + const size_t num_columns = 1; + const size_t n_bins = 32; + + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + DeviceManager device_manager; + auto qu = device_manager.GetQueue(ctx.Device()); + ObjInfo task{ObjInfo::kRegression}; + + auto p_fmat = RandomDataGenerator{num_rows, num_columns, sparsity}.GenerateDMatrix(); + + FeatureInteractionConstraintHost int_constraints; + std::unique_ptr pruner{TreeUpdater::Create("prune", &ctx, &task)}; + + TestHistUpdater updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get()); + updater.SetHistSynchronizer(new BatchHistSynchronizer()); + updater.SetHistRowsAdder(new BatchHistRowsAdder()); + + USMVector gpair(&qu, num_rows); + auto* gpair_ptr = gpair.Data(); + GenerateRandomGPairs(&qu, gpair_ptr, num_rows, false); + + DeviceMatrix dmat; + dmat.Init(qu, p_fmat.get()); + common::GHistIndexMatrix gmat; + gmat.Init(qu, &ctx, dmat, n_bins); + + RegTree tree; + tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0); + tree.ExpandNode(tree[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0); + tree.ExpandNode(tree[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0); + + ExpandEntry node0(0, tree.GetDepth(0)); + ExpandEntry node1(1, tree.GetDepth(1)); + ExpandEntry node2(2, tree.GetDepth(2)); + + auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree); + row_set_collection->AddSplit(0, 1, 2, 42, num_rows - 42); + + updater.TestBuildHistogramsLossGuide(node0, gmat, &tree, gpair); + const auto* hist = updater.TestBuildHistogramsLossGuide(node1, gmat, &tree, gpair); + + ASSERT_EQ((*hist)[0].Size(), n_bins); + ASSERT_EQ((*hist)[1].Size(), n_bins); + ASSERT_EQ((*hist)[2].Size(), n_bins); + + std::vector> hist0_host(n_bins); + std::vector> hist1_host(n_bins); + std::vector> hist2_host(n_bins); + qu.memcpy(hist0_host.data(), (*hist)[0].DataConst(), sizeof(xgboost::detail::GradientPairInternal) * n_bins); + qu.memcpy(hist1_host.data(), (*hist)[1].DataConst(), sizeof(xgboost::detail::GradientPairInternal) * n_bins); + qu.memcpy(hist2_host.data(), (*hist)[2].DataConst(), sizeof(xgboost::detail::GradientPairInternal) * n_bins); + qu.wait(); + + for (size_t idx_bin = 0; idx_bin < n_bins; ++idx_bin) { + EXPECT_NEAR(hist0_host[idx_bin].GetGrad(), hist1_host[idx_bin].GetGrad() + hist2_host[idx_bin].GetGrad(), 1e-6); + EXPECT_NEAR(hist0_host[idx_bin].GetHess(), hist1_host[idx_bin].GetHess() + hist2_host[idx_bin].GetHess(), 1e-6); + } +} + +template +void TestHistUpdaterInitNewNode(const xgboost::tree::TrainParam& param, float sparsity) { + const size_t num_rows = 1u << 8; + const size_t num_columns = 1; + const size_t n_bins = 32; + + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + DeviceManager device_manager; + auto qu = device_manager.GetQueue(ctx.Device()); + ObjInfo task{ObjInfo::kRegression}; + + auto p_fmat = RandomDataGenerator{num_rows, num_columns, sparsity}.GenerateDMatrix(); + + FeatureInteractionConstraintHost int_constraints; + std::unique_ptr pruner{TreeUpdater::Create("prune", &ctx, &task)}; + + TestHistUpdater updater(&ctx, qu, param, std::move(pruner), int_constraints, p_fmat.get()); + updater.SetHistSynchronizer(new BatchHistSynchronizer()); + updater.SetHistRowsAdder(new BatchHistRowsAdder()); + + USMVector gpair(&qu, num_rows); + auto* gpair_ptr = gpair.Data(); + GenerateRandomGPairs(&qu, gpair_ptr, num_rows, false); + + DeviceMatrix dmat; + dmat.Init(qu, p_fmat.get()); + common::GHistIndexMatrix gmat; + gmat.Init(qu, &ctx, dmat, n_bins); + + RegTree tree; + tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0); + ExpandEntry node(ExpandEntry::kRootNid, tree.GetDepth(ExpandEntry::kRootNid)); + + auto* row_set_collection = updater.TestInitData(gmat, gpair, *p_fmat, tree); + auto& row_idxs = row_set_collection->Data(); + const size_t* row_idxs_ptr = row_idxs.DataConst(); + updater.TestBuildHistogramsLossGuide(node, gmat, &tree, gpair); + const auto snode = updater.TestInitNewNode(ExpandEntry::kRootNid, gmat, gpair, *p_fmat, tree); + + GradStats grad_stat; + { + ::sycl::buffer> buff(&grad_stat, 1); + qu.submit([&](::sycl::handler& cgh) { + auto buff_acc = buff.template get_access<::sycl::access::mode::read_write>(cgh); + cgh.single_task<>([=]() { + for (size_t i = 0; i < num_rows; ++i) { + size_t row_idx = row_idxs_ptr[i]; + buff_acc[0] += GradStats(gpair_ptr[row_idx].GetGrad(), + gpair_ptr[row_idx].GetHess()); + } + }); + }).wait_and_throw(); + } + + EXPECT_NEAR(snode.stats.GetGrad(), grad_stat.GetGrad(), 1e-6 * grad_stat.GetGrad()); + EXPECT_NEAR(snode.stats.GetHess(), grad_stat.GetHess(), 1e-6 * grad_stat.GetHess()); +} + +TEST(SyclHistUpdater, Sampling) { + xgboost::tree::TrainParam param; + param.UpdateAllowUnknown(Args{{"subsample", "0.7"}}); + + TestHistUpdaterSampling(param); + TestHistUpdaterSampling(param); +} + +TEST(SyclHistUpdater, InitData) { + xgboost::tree::TrainParam param; + param.UpdateAllowUnknown(Args{{"subsample", "1"}}); + + TestHistUpdaterInitData(param, true); + TestHistUpdaterInitData(param, false); + + TestHistUpdaterInitData(param, true); + TestHistUpdaterInitData(param, false); +} + +TEST(SyclHistUpdater, BuildHistogramsLossGuide) { + xgboost::tree::TrainParam param; + param.UpdateAllowUnknown(Args{{"max_depth", "3"}}); + + TestHistUpdaterBuildHistogramsLossGuide(param, 0.0); + TestHistUpdaterBuildHistogramsLossGuide(param, 0.5); + TestHistUpdaterBuildHistogramsLossGuide(param, 0.0); + TestHistUpdaterBuildHistogramsLossGuide(param, 0.5); +} + +TEST(SyclHistUpdater, InitNewNode) { + xgboost::tree::TrainParam param; + param.UpdateAllowUnknown(Args{{"max_depth", "3"}}); + + TestHistUpdaterInitNewNode(param, 0.0); + TestHistUpdaterInitNewNode(param, 0.5); + TestHistUpdaterInitNewNode(param, 0.0); + TestHistUpdaterInitNewNode(param, 0.5); +} + +} // namespace xgboost::sycl::tree diff --git a/tests/cpp/plugin/test_sycl_partition_builder.cc b/tests/cpp/plugin/test_sycl_partition_builder.cc index 90bc757eb1b0..7e3126a79e81 100644 --- a/tests/cpp/plugin/test_sycl_partition_builder.cc +++ b/tests/cpp/plugin/test_sycl_partition_builder.cc @@ -13,6 +13,108 @@ namespace xgboost::sycl::common { +void TestPartitioning(float sparsity, int max_bins) { + const size_t num_rows = 16; + const size_t num_columns = 1; + + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + DeviceManager device_manager; + auto qu = device_manager.GetQueue(ctx.Device()); + + auto p_fmat = RandomDataGenerator{num_rows, num_columns, sparsity}.GenerateDMatrix(); + sycl::DeviceMatrix dmat; + dmat.Init(qu, p_fmat.get()); + + common::GHistIndexMatrix gmat; + gmat.Init(qu, &ctx, dmat, max_bins); + + RowSetCollection row_set_collection; + auto& row_indices = row_set_collection.Data(); + row_indices.Resize(&qu, num_rows); + size_t* p_row_indices = row_indices.Data(); + + qu.submit([&](::sycl::handler& cgh) { + cgh.parallel_for<>(::sycl::range<1>(num_rows), + [p_row_indices](::sycl::item<1> pid) { + const size_t idx = pid.get_id(0); + p_row_indices[idx] = idx; + }); + }).wait_and_throw(); + row_set_collection.Init(); + + RegTree tree; + tree.ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0); + + const size_t n_nodes = row_set_collection.Size(); + PartitionBuilder partition_builder; + partition_builder.Init(&qu, n_nodes, [&](size_t nid) { + return row_set_collection[nid].Size(); + }); + + std::vector nodes; + nodes.emplace_back(tree::ExpandEntry(0, tree.GetDepth(0))); + + ::sycl::event event; + std::vector split_conditions = {2}; + partition_builder.Partition(gmat, nodes, row_set_collection, + split_conditions, &tree, &event); + qu.wait_and_throw(); + + size_t* data_result = const_cast(row_set_collection[0].begin); + partition_builder.MergeToArray(0, data_result, &event); + qu.wait_and_throw(); + + bst_float split_pt = gmat.cut.Values()[split_conditions[0]]; + + std::vector ridx_left(num_rows, 0); + std::vector ridx_right(num_rows, 0); + for (auto &batch : gmat.p_fmat->GetBatches()) { + const auto& data_vec = batch.data.HostVector(); + const auto& offset_vec = batch.offset.HostVector(); + + size_t begin = offset_vec[0]; + for (size_t idx = 0; idx < offset_vec.size() - 1; ++idx) { + size_t end = offset_vec[idx + 1]; + if (begin < end) { + const auto& entry = data_vec[begin]; + if (entry.fvalue < split_pt) { + ridx_left[idx] = 1; + } else { + ridx_right[idx] = 1; + } + } else { + // missing value + if (tree[0].DefaultLeft()) { + ridx_left[idx] = 1; + } else { + ridx_right[idx] = 1; + } + } + begin = end; + } + } + auto n_left = std::accumulate(ridx_left.begin(), ridx_left.end(), 0); + auto n_right = std::accumulate(ridx_right.begin(), ridx_right.end(), 0); + + std::vector row_indices_host(num_rows); + qu.memcpy(row_indices_host.data(), row_indices.Data(), num_rows * sizeof(size_t)); + qu.wait_and_throw(); + + ASSERT_EQ(n_left, partition_builder.GetNLeftElems(0)); + for (size_t i = 0; i < n_left; ++i) { + auto idx = row_indices_host[i]; + ASSERT_EQ(ridx_left[idx], 1); + } + + ASSERT_EQ(n_right, partition_builder.GetNRightElems(0)); + for (size_t i = 0; i < n_right; ++i) { + auto idx = row_indices_host[num_rows - 1 - i]; + ASSERT_EQ(ridx_right[idx], 1); + } +} + TEST(SyclPartitionBuilder, BasicTest) { constexpr size_t kNodes = 5; // Number of rows for each node @@ -67,7 +169,7 @@ TEST(SyclPartitionBuilder, BasicTest) { std::vector v(*std::max_element(rows.begin(), rows.end())); size_t row_id = 0; for(size_t nid = 0; nid < kNodes; ++nid) { - builder.MergeToArray(nid, v.data(), event); + builder.MergeToArray(nid, v.data(), &event); qu.wait(); // Check that row_id for left side are correct @@ -88,4 +190,20 @@ TEST(SyclPartitionBuilder, BasicTest) { } } +TEST(SyclPartitionBuilder, PartitioningSparce) { + TestPartitioning(0.3, 256); +} + +TEST(SyclPartitionBuilder, PartitioningDence8Bits) { + TestPartitioning(0.0, 256); +} + +TEST(SyclPartitionBuilder, PartitioningDence16Bits) { + TestPartitioning(0.0, 256 + 1); +} + +TEST(SyclPartitionBuilder, PartitioningDence32Bits) { + TestPartitioning(0.0, (1u << 16) + 1); +} + } // namespace xgboost::common diff --git a/tests/cpp/plugin/test_sycl_predictor.cc b/tests/cpp/plugin/test_sycl_predictor.cc index d5b3a5e5cd9a..7bd788a3b071 100755 --- a/tests/cpp/plugin/test_sycl_predictor.cc +++ b/tests/cpp/plugin/test_sycl_predictor.cc @@ -43,7 +43,7 @@ TEST(SyclPredictor, ExternalMemory) { } TEST(SyclPredictor, InplacePredict) { - bst_row_t constexpr kRows{128}; + bst_idx_t constexpr kRows{128}; bst_feature_t constexpr kCols{64}; Context ctx; auto gen = RandomDataGenerator{kRows, kCols, 0.5}.Device(ctx.Device()); @@ -106,4 +106,4 @@ TEST(SyclPredictor, Multi) { TestVectorLeafPrediction(&ctx); } -} // namespace xgboost \ No newline at end of file +} // namespace xgboost diff --git a/tests/cpp/plugin/test_sycl_quantile_hist_builder.cc b/tests/cpp/plugin/test_sycl_quantile_hist_builder.cc new file mode 100644 index 000000000000..4bf7bd962750 --- /dev/null +++ b/tests/cpp/plugin/test_sycl_quantile_hist_builder.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020-2024 by XGBoost contributors + */ +#include + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#pragma GCC diagnostic ignored "-W#pragma-messages" +#include +#include +#include "../../../plugin/sycl/tree/updater_quantile_hist.h" // for QuantileHistMaker +#pragma GCC diagnostic pop + +namespace xgboost::sycl::tree { +TEST(SyclQuantileHistMaker, Basic) { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + ObjInfo task{ObjInfo::kRegression}; + std::unique_ptr updater{TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)}; + + ASSERT_EQ(updater->Name(), "grow_quantile_histmaker_sycl"); +} + +TEST(SyclQuantileHistMaker, JsonIO) { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + ObjInfo task{ObjInfo::kRegression}; + Json config {Object()}; + { + std::unique_ptr updater{TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)}; + updater->Configure({{"max_depth", std::to_string(42)}}); + updater->Configure({{"single_precision_histogram", std::to_string(true)}}); + updater->SaveConfig(&config); + } + + { + std::unique_ptr updater{TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)}; + updater->LoadConfig(config); + + Json new_config {Object()}; + updater->SaveConfig(&new_config); + + ASSERT_EQ(config, new_config); + + auto max_depth = atoi(get(new_config["train_param"]["max_depth"]).c_str()); + ASSERT_EQ(max_depth, 42); + + auto single_precision_histogram = atoi(get(new_config["sycl_hist_train_param"]["single_precision_histogram"]).c_str()); + ASSERT_EQ(single_precision_histogram, 1); + } + +} +} // namespace xgboost::sycl::tree diff --git a/tests/cpp/plugin/test_sycl_split_evaluator.cc b/tests/cpp/plugin/test_sycl_split_evaluator.cc new file mode 100644 index 000000000000..507490fd17e1 --- /dev/null +++ b/tests/cpp/plugin/test_sycl_split_evaluator.cc @@ -0,0 +1,134 @@ +/** + * Copyright 2020-2024 by XGBoost contributors + */ +#include +#include + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#pragma GCC diagnostic ignored "-W#pragma-messages" +#include "../../../plugin/sycl/tree/split_evaluator.h" +#pragma GCC diagnostic pop + +#include "../../../plugin/sycl/device_manager.h" +#include "../helpers.h" + +namespace xgboost::sycl::tree { + +template +void BasicTestSplitEvaluator(const std::string& monotone_constraints, bool has_constrains) { + const size_t n_columns = 2; + + xgboost::tree::TrainParam param; + param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, + {"reg_lambda", "0"}, + {"monotone_constraints", monotone_constraints}}); + + DeviceManager device_manager; + auto qu = device_manager.GetQueue(DeviceOrd::SyclDefault()); + + TreeEvaluator tree_evaluator(qu, param, n_columns); + { + // Check correctness of has_constrains flag + ASSERT_EQ(tree_evaluator.HasConstraint(), has_constrains); + } + + auto split_evaluator = tree_evaluator.GetEvaluator(); + { + // Check if params were inititialised correctly + ASSERT_EQ(split_evaluator.param.min_child_weight, param.min_child_weight); + ASSERT_EQ(split_evaluator.param.reg_lambda, param.reg_lambda); + ASSERT_EQ(split_evaluator.param.reg_alpha, param.reg_alpha); + ASSERT_EQ(split_evaluator.param.max_delta_step, param.max_delta_step); + } +} + +template +void TestSplitEvaluator(const std::string& monotone_constraints) { + const size_t n_columns = 2; + + xgboost::tree::TrainParam param; + param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, + {"reg_lambda", "0"}, + {"monotone_constraints", monotone_constraints}}); + + DeviceManager device_manager; + auto qu = device_manager.GetQueue(DeviceOrd::SyclDefault()); + + TreeEvaluator tree_evaluator(qu, param, n_columns); + auto split_evaluator = tree_evaluator.GetEvaluator(); + { + // Test ThresholdL1 + const GradientSumT alpha = 0.5; + { + const GradientSumT val = 0.0; + const auto trh = split_evaluator.ThresholdL1(val, alpha); + ASSERT_EQ(trh, 0.0); + } + + { + const GradientSumT val = 1.0; + const auto trh = split_evaluator.ThresholdL1(val, alpha); + ASSERT_EQ(trh, val - alpha); + } + + { + const GradientSumT val = -1.0; + const auto trh = split_evaluator.ThresholdL1(val, alpha); + ASSERT_EQ(trh, val + alpha); + } + } + + { + constexpr float eps = 1e-8; + tree_evaluator.AddSplit(0, 1, 2, 0, 0.3, 0.7); + + GradStats left(0.1, 0.2); + GradStats right(0.3, 0.4); + bst_node_t nidx = 0; + bst_feature_t fidx = 0; + + GradientSumT wleft = split_evaluator.CalcWeight(nidx, left); + // wleft = -grad/hess = -0.1/0.2 + EXPECT_NEAR(wleft, -0.5, eps); + GradientSumT wright = split_evaluator.CalcWeight(nidx, right); + // wright = -grad/hess = -0.3/0.4 + EXPECT_NEAR(wright, -0.75, eps); + + GradientSumT gweight_left = split_evaluator.CalcGainGivenWeight(nidx, left, wleft); + // gweight_left = left.grad**2 / left.hess = 0.1*0.1/0.2 = 0.05 + EXPECT_NEAR(gweight_left, 0.05, eps); + // gweight_left = right.grad**2 / right.hess = 0.3*0.3/0.4 = 0.225 + GradientSumT gweight_right = split_evaluator.CalcGainGivenWeight(nidx, right, wright); + EXPECT_NEAR(gweight_right, 0.225, eps); + + GradientSumT split_gain = split_evaluator.CalcSplitGain(nidx, fidx, left, right); + if (!tree_evaluator.HasConstraint()) { + EXPECT_NEAR(split_gain, gweight_left + gweight_right, eps); + } else { + // Parameters are chosen to have -inf here + ASSERT_EQ(split_gain, -std::numeric_limits::infinity()); + } + } +} + +TEST(SyclSplitEvaluator, BasicTest) { + BasicTestSplitEvaluator("( 0, 0)", false); + BasicTestSplitEvaluator("( 1, 0)", true); + BasicTestSplitEvaluator("( 0, 1)", true); + BasicTestSplitEvaluator("(-1, 0)", true); + BasicTestSplitEvaluator("( 0, -1)", true); + BasicTestSplitEvaluator("( 1, 1)", true); + BasicTestSplitEvaluator("(-1, -1)", true); + BasicTestSplitEvaluator("( 1, -1)", true); + BasicTestSplitEvaluator("(-1, 1)", true); +} + +TEST(SyclSplitEvaluator, TestMath) { + // Without constraints + TestSplitEvaluator("( 0, 0)"); + // With constraints + TestSplitEvaluator("( 1, 0)"); +} + +} // namespace xgboost::sycl::tree diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index 669827ee4e92..637b77b25041 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -12,6 +12,7 @@ #include "../../../src/data/proxy_dmatrix.h" #include "../../../src/gbm/gbtree.h" #include "../../../src/gbm/gbtree_model.h" +#include "../collective/test_worker.h" // for TestDistributedGlobal #include "../filesystem.h" // dmlc::TemporaryDirectory #include "../helpers.h" #include "test_predictor.h" @@ -43,7 +44,7 @@ void TestColumnSplit() { TEST(CpuPredictor, BasicColumnSplit) { auto constexpr kWorldSize = 2; - RunWithInMemoryCommunicator(kWorldSize, TestColumnSplit); + collective::TestDistributedGlobal(kWorldSize, TestColumnSplit); } TEST(CpuPredictor, IterationRange) { @@ -65,7 +66,7 @@ TEST(CpuPredictor, ExternalMemory) { } TEST(CpuPredictor, InplacePredict) { - bst_row_t constexpr kRows{128}; + bst_idx_t constexpr kRows{128}; bst_feature_t constexpr kCols{64}; Context ctx; auto gen = RandomDataGenerator{kRows, kCols, 0.5}.Device(ctx.Device()); @@ -83,7 +84,7 @@ TEST(CpuPredictor, InplacePredict) { { HostDeviceVector data; - HostDeviceVector rptrs; + HostDeviceVector rptrs; HostDeviceVector columns; gen.GenerateCSR(&data, &rptrs, &columns); auto data_interface = GetArrayInterface(&data, kRows * kCols, 1); @@ -157,7 +158,7 @@ TEST(CPUPredictor, CategoricalPrediction) { TEST(CPUPredictor, CategoricalPredictionColumnSplit) { auto constexpr kWorldSize = 2; - RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPrediction, false, true); + collective::TestDistributedGlobal(kWorldSize, [] { TestCategoricalPrediction(false, true); }); } TEST(CPUPredictor, CategoricalPredictLeaf) { @@ -168,7 +169,7 @@ TEST(CPUPredictor, CategoricalPredictLeaf) { TEST(CPUPredictor, CategoricalPredictLeafColumnSplit) { auto constexpr kWorldSize = 2; Context ctx; - RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPredictLeaf, &ctx, true); + collective::TestDistributedGlobal(kWorldSize, [&] { TestCategoricalPredictLeaf(&ctx, true); }); } TEST(CpuPredictor, UpdatePredictionCache) { @@ -183,7 +184,8 @@ TEST(CpuPredictor, LesserFeatures) { TEST(CpuPredictor, LesserFeaturesColumnSplit) { auto constexpr kWorldSize = 2; - RunWithInMemoryCommunicator(kWorldSize, TestPredictionWithLesserFeaturesColumnSplit, false); + collective::TestDistributedGlobal(kWorldSize, + [] { TestPredictionWithLesserFeaturesColumnSplit(false); }); } TEST(CpuPredictor, Sparse) { diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index 50e036b90794..4895fb63fb79 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -12,6 +12,7 @@ #include "../../../src/data/device_adapter.cuh" #include "../../../src/data/proxy_dmatrix.h" #include "../../../src/gbm/gbtree_model.h" +#include "../collective/test_worker.h" // for TestDistributedGlobal, BaseMGPUTest #include "../helpers.h" #include "test_predictor.h" @@ -85,7 +86,7 @@ void VerifyBasicColumnSplit(std::array, 32> const& expected_r } } // anonymous namespace -class MGPUPredictorTest : public BaseMGPUTest {}; +class MGPUPredictorTest : public collective::BaseMGPUTest {}; TEST_F(MGPUPredictorTest, BasicColumnSplit) { auto ctx = MakeCUDACtx(0); @@ -111,7 +112,8 @@ TEST_F(MGPUPredictorTest, BasicColumnSplit) { result[i - 1] = out_predictions_h; } - DoTest(VerifyBasicColumnSplit, result); + this->DoTest([&] { VerifyBasicColumnSplit(result); }, true); + this->DoTest([&] { VerifyBasicColumnSplit(result); }, false); } TEST(GPUPredictor, EllpackBasic) { @@ -209,7 +211,8 @@ TEST(GpuPredictor, LesserFeatures) { } TEST_F(MGPUPredictorTest, LesserFeaturesColumnSplit) { - RunWithInMemoryCommunicator(world_size_, TestPredictionWithLesserFeaturesColumnSplit, true); + this->DoTest([] { TestPredictionWithLesserFeaturesColumnSplit(true); }, true); + this->DoTest([] { TestPredictionWithLesserFeaturesColumnSplit(true); }, false); } // Very basic test of empty model @@ -277,7 +280,7 @@ TEST(GPUPredictor, IterationRange) { } TEST_F(MGPUPredictorTest, IterationRangeColumnSplit) { - TestIterationRangeColumnSplit(world_size_, true); + TestIterationRangeColumnSplit(common::AllVisibleGPUs(), true); } TEST(GPUPredictor, CategoricalPrediction) { @@ -285,7 +288,8 @@ TEST(GPUPredictor, CategoricalPrediction) { } TEST_F(MGPUPredictorTest, CategoricalPredictionColumnSplit) { - RunWithInMemoryCommunicator(world_size_, TestCategoricalPrediction, true, true); + this->DoTest([] { TestCategoricalPrediction(true, true); }, true); + this->DoTest([] { TestCategoricalPrediction(true, true); }, false); } TEST(GPUPredictor, CategoricalPredictLeaf) { @@ -294,8 +298,18 @@ TEST(GPUPredictor, CategoricalPredictLeaf) { } TEST_F(MGPUPredictorTest, CategoricalPredictionLeafColumnSplit) { - auto ctx = MakeCUDACtx(common::AllVisibleGPUs() == 1 ? 0 : collective::GetRank()); - RunWithInMemoryCommunicator(world_size_, TestCategoricalPredictLeaf, &ctx, true); + this->DoTest( + [&] { + auto ctx = MakeCUDACtx(collective::GetRank()); + TestCategoricalPredictLeaf(&ctx, true); + }, + true); + this->DoTest( + [&] { + auto ctx = MakeCUDACtx(collective::GetRank()); + TestCategoricalPredictLeaf(&ctx, true); + }, + false); } TEST(GPUPredictor, PredictLeafBasic) { @@ -325,7 +339,7 @@ TEST(GPUPredictor, Sparse) { } TEST_F(MGPUPredictorTest, SparseColumnSplit) { - TestSparsePredictionColumnSplit(world_size_, true, 0.2); - TestSparsePredictionColumnSplit(world_size_, true, 0.8); + TestSparsePredictionColumnSplit(common::AllVisibleGPUs(), true, 0.2); + TestSparsePredictionColumnSplit(common::AllVisibleGPUs(), true, 0.8); } } // namespace xgboost::predictor diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc index 0d715760853b..fde0e480b8cd 100644 --- a/tests/cpp/predictor/test_predictor.cc +++ b/tests/cpp/predictor/test_predictor.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020-2023 by XGBoost Contributors + * Copyright 2020-2024, XGBoost Contributors */ #include "test_predictor.h" @@ -10,7 +10,6 @@ #include // for PredictionCacheEntry, Predictor, Predic... #include // for StringView -#include // for max #include // for numeric_limits #include // for shared_ptr #include // for unordered_map @@ -18,6 +17,7 @@ #include "../../../src/common/bitfield.h" // for LBitField32 #include "../../../src/data/iterative_dmatrix.h" // for IterativeDMatrix #include "../../../src/data/proxy_dmatrix.h" // for DMatrixProxy +#include "../collective/test_worker.h" // for TestDistributedGlobal #include "../helpers.h" // for GetDMatrixFromData, RandomDataGenerator #include "xgboost/json.h" // for Json, Object, get, String #include "xgboost/linalg.h" // for MakeVec, Tensor, TensorView, Vector @@ -186,7 +186,7 @@ void TestTrainingPrediction(Context const *ctx, size_t rows, size_t bins, } } -void TestInplacePrediction(Context const *ctx, std::shared_ptr x, bst_row_t rows, +void TestInplacePrediction(Context const *ctx, std::shared_ptr x, bst_idx_t rows, bst_feature_t cols) { std::size_t constexpr kClasses { 4 }; auto gen = RandomDataGenerator{rows, cols, 0.5}.Device(ctx->Device()); @@ -255,7 +255,7 @@ std::unique_ptr LearnerForTest(Context const *ctx, std::shared_ptr m_test, std::shared_ptr m_invalid) { HostDeviceVector prediction; @@ -593,9 +593,23 @@ void TestIterationRangeColumnSplit(int world_size, bool use_gpu) { Json sliced_model{Object{}}; sliced->SaveModel(&sliced_model); - RunWithInMemoryCommunicator(world_size, VerifyIterationRangeColumnSplit, use_gpu, ranged_model, - sliced_model, kRows, kCols, kClasses, margin_ranged, margin_sliced, - leaf_ranged, leaf_sliced); +#if !defined(XGBOOST_USE_NCCL) + if (use_gpu) { + GTEST_SKIP_("Not compiled with NCCL"); + return; + } +#endif // defined(XGBOOST_USE_NCCL) + collective::TestDistributedGlobal(world_size, [&] { + VerifyIterationRangeColumnSplit(use_gpu, ranged_model, sliced_model, kRows, kCols, kClasses, + margin_ranged, margin_sliced, leaf_ranged, leaf_sliced); + }); + +#if defined(XGBOOST_USE_FEDERATED) + collective::TestFederatedGlobal(world_size, [&] { + VerifyIterationRangeColumnSplit(use_gpu, ranged_model, sliced_model, kRows, kCols, kClasses, + margin_ranged, margin_sliced, leaf_ranged, leaf_sliced); + }); +#endif // defined(XGBOOST_USE_FEDERATED) } void TestSparsePrediction(Context const *ctx, float sparsity) { @@ -701,8 +715,23 @@ void TestSparsePredictionColumnSplit(int world_size, bool use_gpu, float sparsit learner->SetParam("device", ctx.DeviceName()); learner->Predict(Xy, false, &sparse_predt, 0, 0); - RunWithInMemoryCommunicator(world_size, VerifySparsePredictionColumnSplit, use_gpu, model, - kRows, kCols, sparsity, sparse_predt.HostVector()); +#if !defined(XGBOOST_USE_NCCL) + if (use_gpu) { + GTEST_SKIP_("Not compiled with NCCL."); + return; + } +#endif // defined(XGBOOST_USE_CUDA) + collective::TestDistributedGlobal(world_size, [&] { + VerifySparsePredictionColumnSplit(use_gpu, model, kRows, kCols, sparsity, + sparse_predt.HostVector()); + }); + +#if defined(XGBOOST_USE_FEDERATED) + collective::TestFederatedGlobal(world_size, [&] { + VerifySparsePredictionColumnSplit(use_gpu, model, kRows, kCols, sparsity, + sparse_predt.HostVector()); + }); +#endif // defined(XGBOOST_USE_FEDERATED) } void TestVectorLeafPrediction(Context const *ctx) { diff --git a/tests/cpp/predictor/test_predictor.h b/tests/cpp/predictor/test_predictor.h index a65b60579e61..1ccd35102b2d 100644 --- a/tests/cpp/predictor/test_predictor.h +++ b/tests/cpp/predictor/test_predictor.h @@ -92,7 +92,7 @@ void TestTrainingPrediction(Context const* ctx, size_t rows, size_t bins, std::shared_ptr p_full, std::shared_ptr p_hist, bool check_contribs = false); -void TestInplacePrediction(Context const* ctx, std::shared_ptr x, bst_row_t rows, +void TestInplacePrediction(Context const* ctx, std::shared_ptr x, bst_idx_t rows, bst_feature_t cols); void TestPredictionWithLesserFeatures(Context const* ctx); diff --git a/tests/cpp/rabit/allreduce_base_test.cc b/tests/cpp/rabit/allreduce_base_test.cc deleted file mode 100644 index 55cce5c7dfd4..000000000000 --- a/tests/cpp/rabit/allreduce_base_test.cc +++ /dev/null @@ -1,42 +0,0 @@ -#define RABIT_CXXTESTDEFS_H -#if !defined(_WIN32) -#include - -#include -#include -#include "../../../rabit/src/allreduce_base.h" - -TEST(AllreduceBase, InitTask) -{ - rabit::engine::AllreduceBase base; - - std::string rabit_task_id = "rabit_task_id=1"; - char cmd[rabit_task_id.size()+1]; - std::copy(rabit_task_id.begin(), rabit_task_id.end(), cmd); - cmd[rabit_task_id.size()] = '\0'; - - char* argv[] = {cmd}; - base.Init(1, argv); - EXPECT_EQ(base.task_id, "1"); -} - -TEST(AllreduceBase, InitWithRingReduce) -{ - rabit::engine::AllreduceBase base; - - std::string rabit_task_id = "rabit_task_id=1"; - char cmd[rabit_task_id.size()+1]; - std::copy(rabit_task_id.begin(), rabit_task_id.end(), cmd); - cmd[rabit_task_id.size()] = '\0'; - - std::string rabit_reduce_ring_mincount = "rabit_reduce_ring_mincount=1"; - char cmd2[rabit_reduce_ring_mincount.size()+1]; - std::copy(rabit_reduce_ring_mincount.begin(), rabit_reduce_ring_mincount.end(), cmd2); - cmd2[rabit_reduce_ring_mincount.size()] = '\0'; - - char* argv[] = {cmd, cmd2}; - base.Init(2, argv); - EXPECT_EQ(base.task_id, "1"); - EXPECT_EQ(base.reduce_ring_mincount, 1ul); -} -#endif // !defined(_WIN32) diff --git a/tests/cpp/rabit/test_utils.cc b/tests/cpp/rabit/test_utils.cc deleted file mode 100644 index 0b8787bdd6b3..000000000000 --- a/tests/cpp/rabit/test_utils.cc +++ /dev/null @@ -1,6 +0,0 @@ -#include -#include - -TEST(Utils, Assert) { - EXPECT_THROW({rabit::utils::Assert(false, "foo");}, dmlc::Error); -} diff --git a/tests/cpp/test_helpers.cc b/tests/cpp/test_helpers.cc index 79d8d2475181..f582ba564b61 100644 --- a/tests/cpp/test_helpers.cc +++ b/tests/cpp/test_helpers.cc @@ -11,7 +11,7 @@ TEST(RandomDataGenerator, DMatrix) { auto p_dmatrix = RandomDataGenerator{kRows, kCols, kSparsity}.GenerateDMatrix(); HostDeviceVector csr_value; - HostDeviceVector csr_rptr; + HostDeviceVector csr_rptr; HostDeviceVector csr_cidx; RandomDataGenerator{kRows, kCols, kSparsity}.GenerateCSR(&csr_value, &csr_rptr, &csr_cidx); diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index 04f1d35b4002..976ae2147a06 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -1,18 +1,16 @@ /** - * Copyright (c) 2017-2023, XGBoost contributors + * Copyright 2017-2024, XGBoost contributors */ -#include #include -#include // for Learner -#include // for LogCheck_NE, CHECK_NE, LogCheck_EQ -#include // for ObjFunction -#include // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR +#include +#include // for Learner +#include // for LogCheck_NE, CHECK_NE, LogCheck_EQ +#include // for ObjFunction +#include // for XGBOOST_VER_MAJOR, XGBOOST_VER_MINOR #include // for equal, transform -#include // for int32_t, int64_t, uint32_t #include // for size_t #include // for ofstream -#include // for back_insert_iterator, back_inserter #include // for numeric_limits #include // for map #include // for unique_ptr, shared_ptr, __shared_ptr_... @@ -28,9 +26,9 @@ #include "../../src/common/io.h" // for LoadSequentialFile #include "../../src/common/linalg_op.h" // for ElementWiseTransformHost, begin, end #include "../../src/common/random.h" // for GlobalRandom +#include "./collective/test_worker.h" // for TestDistributedGlobal #include "dmlc/io.h" // for Stream #include "dmlc/omp.h" // for omp_get_max_threads -#include "dmlc/registry.h" // for Registry #include "filesystem.h" // for TemporaryDirectory #include "helpers.h" // for GetBaseScore, RandomDataGenerator #include "objective_helpers.h" // for MakeObjNamesForTest, ObjTestNameGenerator @@ -82,9 +80,7 @@ TEST(Learner, ParameterValidation) { // whitespace learner->SetParam("tree method", "exact"); - EXPECT_THAT([&] { learner->Configure(); }, - ::testing::ThrowsMessage( - ::testing::HasSubstr(R"("tree method" contains whitespace)"))); + ASSERT_THAT([&] { learner->Configure(); }, GMockThrow(R"("tree method" contains whitespace)")); } TEST(Learner, CheckGroup) { @@ -105,9 +101,9 @@ TEST(Learner, CheckGroup) { labels[i] = i % 2; } - p_mat->SetInfo("weight", static_cast(weight.data()), DataType::kFloat32, kNumGroups); - p_mat->SetInfo("group", group.data(), DataType::kUInt32, kNumGroups); - p_mat->SetInfo("label", labels.data(), DataType::kFloat32, kNumRows); + p_mat->SetInfo("weight", Make1dInterfaceTest(weight.data(), kNumGroups)); + p_mat->SetInfo("group", Make1dInterfaceTest(group.data(), kNumGroups)); + p_mat->SetInfo("label", Make1dInterfaceTest(labels.data(), kNumRows)); std::vector> mat = {p_mat}; auto learner = std::unique_ptr(Learner::Create(mat)); @@ -117,7 +113,7 @@ TEST(Learner, CheckGroup) { group.resize(kNumGroups+1); group[3] = 4; group[4] = 1; - p_mat->SetInfo("group", group.data(), DataType::kUInt32, kNumGroups+1); + p_mat->SetInfo("group", Make1dInterfaceTest(group.data(), kNumGroups+1)); EXPECT_ANY_THROW(learner->UpdateOneIter(0, p_mat)); } @@ -134,7 +130,7 @@ TEST(Learner, SLOW_CheckMultiBatch) { // NOLINT for (size_t i = 0; i < num_row; ++i) { labels[i] = i % 2; } - dmat->SetInfo("label", labels.data(), DataType::kFloat32, num_row); + dmat->SetInfo("label", Make1dInterfaceTest(labels.data(), num_row)); std::vector> mat{dmat}; auto learner = std::unique_ptr(Learner::Create(mat)); learner->SetParams(Args{{"objective", "binary:logistic"}}); @@ -219,7 +215,7 @@ TEST(Learner, JsonModelIO) { } TEST(Learner, ConfigIO) { - bst_row_t n_samples = 128; + bst_idx_t n_samples = 128; bst_feature_t n_features = 12; std::shared_ptr p_fmat{ RandomDataGenerator{n_samples, n_features, 0}.GenerateDMatrix(true, false, 2)}; @@ -662,7 +658,7 @@ class TestColumnSplit : public ::testing::TestWithParam { auto const world_size = collective::GetWorldSize(); auto const rank = collective::GetRank(); - auto p_fmat = MakeFmatForObjTest(objective); + auto p_fmat = MakeFmatForObjTest(objective, 10, 10); std::shared_ptr sliced{p_fmat->SliceCol(world_size, rank)}; std::unique_ptr learner{Learner::Create({sliced})}; learner->SetParam("tree_method", "approx"); @@ -686,7 +682,7 @@ class TestColumnSplit : public ::testing::TestWithParam { public: void Run(std::string objective) { - auto p_fmat = MakeFmatForObjTest(objective); + auto p_fmat = MakeFmatForObjTest(objective, 10, 10); std::unique_ptr learner{Learner::Create({p_fmat})}; learner->SetParam("tree_method", "approx"); learner->SetParam("objective", objective); @@ -707,7 +703,9 @@ class TestColumnSplit : public ::testing::TestWithParam { auto constexpr kWorldSize{3}; auto call = [this, &objective](auto&... args) { TestBaseScore(objective, args...); }; auto score = GetBaseScore(config); - RunWithInMemoryCommunicator(kWorldSize, call, score, model); + collective::TestDistributedGlobal(kWorldSize, [&] { + call(score, model); + }); } }; @@ -740,7 +738,7 @@ void VerifyColumnSplitWithArgs(std::string const& tree_method, bool use_gpu, Arg Json const& expected_model) { auto const world_size = collective::GetWorldSize(); auto const rank = collective::GetRank(); - auto p_fmat = MakeFmatForObjTest(""); + auto p_fmat = MakeFmatForObjTest("", 10, 10); std::shared_ptr sliced{p_fmat->SliceCol(world_size, rank)}; std::string device = "cpu"; if (use_gpu) { @@ -751,82 +749,99 @@ void VerifyColumnSplitWithArgs(std::string const& tree_method, bool use_gpu, Arg ASSERT_EQ(model, expected_model); } -void TestColumnSplitWithArgs(std::string const& tree_method, bool use_gpu, Args const& args) { - auto p_fmat = MakeFmatForObjTest(""); +void TestColumnSplitWithArgs(std::string const& tree_method, bool use_gpu, Args const& args, + bool federated) { + auto p_fmat = MakeFmatForObjTest("", 10, 10); std::string device = use_gpu ? "cuda:0" : "cpu"; auto model = GetModelWithArgs(p_fmat, tree_method, device, args); auto world_size{3}; if (use_gpu) { world_size = common::AllVisibleGPUs(); - // Simulate MPU on a single GPU. - if (world_size == 1) { + // Simulate MPU on a single GPU. Federated doesn't use nccl, can run multiple + // instances on the same GPU. + if (world_size == 1 && federated) { world_size = 3; } } - RunWithInMemoryCommunicator(world_size, VerifyColumnSplitWithArgs, tree_method, use_gpu, args, - model); -} - -void TestColumnSplitColumnSampler(std::string const& tree_method, bool use_gpu) { - Args args{{"colsample_bytree", "0.5"}, {"colsample_bylevel", "0.6"}, {"colsample_bynode", "0.7"}}; - TestColumnSplitWithArgs(tree_method, use_gpu, args); -} - -void TestColumnSplitInteractionConstraints(std::string const& tree_method, bool use_gpu) { - Args args{{"interaction_constraints", "[[0, 5, 7], [2, 8, 9], [1, 3, 6]]"}}; - TestColumnSplitWithArgs(tree_method, use_gpu, args); -} - -void TestColumnSplitMonotoneConstraints(std::string const& tree_method, bool use_gpu) { - Args args{{"monotone_constraints", "(1,-1,0,1,1,-1,-1,0,0,1)"}}; - TestColumnSplitWithArgs(tree_method, use_gpu, args); -} -} // anonymous namespace - -TEST(ColumnSplitColumnSampler, Approx) { TestColumnSplitColumnSampler("approx", false); } - -TEST(ColumnSplitColumnSampler, Hist) { TestColumnSplitColumnSampler("hist", false); } - -#if defined(XGBOOST_USE_CUDA) -TEST(MGPUColumnSplitColumnSampler, GPUApprox) { TestColumnSplitColumnSampler("approx", true); } - -TEST(MGPUColumnSplitColumnSampler, GPUHist) { TestColumnSplitColumnSampler("hist", true); } -#endif // defined(XGBOOST_USE_CUDA) - -TEST(ColumnSplitInteractionConstraints, Approx) { - TestColumnSplitInteractionConstraints("approx", false); + if (federated) { +#if defined(XGBOOST_USE_FEDERATED) + collective::TestFederatedGlobal( + world_size, [&] { VerifyColumnSplitWithArgs(tree_method, use_gpu, args, model); }); +#else + GTEST_SKIP_("Not compiled with federated learning."); +#endif // defined(XGBOOST_USE_FEDERATED) + } else { +#if !defined(XGBOOST_USE_NCCL) + if (use_gpu) { + GTEST_SKIP_("Not compiled with NCCL."); + return; + } +#endif // defined(XGBOOST_USE_NCCL) + collective::TestDistributedGlobal( + world_size, [&] { VerifyColumnSplitWithArgs(tree_method, use_gpu, args, model); }); + } } -TEST(ColumnSplitInteractionConstraints, Hist) { - TestColumnSplitInteractionConstraints("hist", false); -} +class ColumnSplitTrainingTest + : public ::testing::TestWithParam> { + public: + static void TestColumnSplitColumnSampler(std::string const& tree_method, bool use_gpu, + bool federated) { + Args args{ + {"colsample_bytree", "0.5"}, {"colsample_bylevel", "0.6"}, {"colsample_bynode", "0.7"}}; + TestColumnSplitWithArgs(tree_method, use_gpu, args, federated); + } + static void TestColumnSplitInteractionConstraints(std::string const& tree_method, bool use_gpu, + bool federated) { + Args args{{"interaction_constraints", "[[0, 5, 7], [2, 8, 9], [1, 3, 6]]"}}; + TestColumnSplitWithArgs(tree_method, use_gpu, args, federated); + } + static void TestColumnSplitMonotoneConstraints(std::string const& tree_method, bool use_gpu, + bool federated) { + Args args{{"monotone_constraints", "(1,-1,0,1,1,-1,-1,0,0,1)"}}; + TestColumnSplitWithArgs(tree_method, use_gpu, args, federated); + } +}; +auto MakeParamsForTest() { + std::vector> configs; + for (auto tm : {"hist", "approx"}) { #if defined(XGBOOST_USE_CUDA) -TEST(MGPUColumnSplitInteractionConstraints, GPUApprox) { - TestColumnSplitInteractionConstraints("approx", true); -} - -TEST(MGPUColumnSplitInteractionConstraints, GPUHist) { - TestColumnSplitInteractionConstraints("hist", true); + std::array use_gpu{true, false}; +#else + std::array use_gpu{false}; +#endif + for (auto i : use_gpu) { +#if defined(XGBOOST_USE_FEDERATED) + std::array fed{true, false}; +#else + std::array fed{false}; +#endif + for (auto j : fed) { + configs.emplace_back(tm, i, j); + } + } + } + return configs; } -#endif // defined(XGBOOST_USE_CUDA) +} // anonymous namespace -TEST(ColumnSplitMonotoneConstraints, Approx) { - TestColumnSplitMonotoneConstraints("approx", false); +TEST_P(ColumnSplitTrainingTest, ColumnSampler) { + auto param = GetParam(); + std::apply(TestColumnSplitColumnSampler, param); } -TEST(ColumnSplitMonotoneConstraints, Hist) { - TestColumnSplitMonotoneConstraints("hist", false); +TEST_P(ColumnSplitTrainingTest, InteractionConstraints) { + auto param = GetParam(); + std::apply(TestColumnSplitInteractionConstraints, param); } -#if defined(XGBOOST_USE_CUDA) -TEST(MGPUColumnSplitMonotoneConstraints, GPUApprox) { - TestColumnSplitMonotoneConstraints("approx", true); +TEST_P(ColumnSplitTrainingTest, MonotoneConstraints) { + auto param = GetParam(); + std::apply(TestColumnSplitMonotoneConstraints, param); } -TEST(MGPUColumnSplitMonotoneConstraints, GPUHist) { - TestColumnSplitMonotoneConstraints("hist", true); -} -#endif // defined(XGBOOST_USE_CUDA) +INSTANTIATE_TEST_SUITE_P(ColumnSplit, ColumnSplitTrainingTest, + ::testing::ValuesIn(MakeParamsForTest())); } // namespace xgboost diff --git a/tests/cpp/test_main.cc b/tests/cpp/test_main.cc index b93329c2e788..37be97f08a03 100644 --- a/tests/cpp/test_main.cc +++ b/tests/cpp/test_main.cc @@ -1,15 +1,16 @@ -// Copyright by Contributors +/** + * Copyright 2016-2024, XGBoost Contributors + */ #include #include #include + #include -#include -#include #include "helpers.h" -int main(int argc, char ** argv) { - xgboost::Args args {{"verbosity", "2"}}; +int main(int argc, char** argv) { + xgboost::Args args{{"verbosity", "2"}}; xgboost::ConsoleLogger::Configure(args); testing::InitGoogleTest(&argc, argv); diff --git a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu index 862bc6bfcca9..72a8b5449f0f 100644 --- a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu +++ b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu @@ -1,12 +1,12 @@ /** - * Copyright 2020-2023, XGBoost contributors + * Copyright 2020-2024, XGBoost contributors */ #include #include #include "../../../../src/tree/gpu_hist/evaluate_splits.cuh" +#include "../../collective/test_worker.h" // for BaseMGPUTest #include "../../helpers.h" -#include "../../histogram_helpers.h" #include "../test_evaluate_splits.h" // TestPartitionBasedSplit namespace xgboost::tree { @@ -17,13 +17,13 @@ auto ZeroParam() { tparam.UpdateAllowUnknown(args); return tparam; } -} // anonymous namespace -inline GradientQuantiser DummyRoundingFactor(Context const* ctx) { +GradientQuantiser DummyRoundingFactor(Context const* ctx) { thrust::device_vector gpair(1); gpair[0] = {1000.f, 1000.f}; // Tests should not exceed sum of 1000 return {ctx, dh::ToSpan(gpair), MetaInfo()}; } +} // anonymous namespace thrust::device_vector ConvertToInteger(Context const* ctx, std::vector x) { @@ -363,7 +363,7 @@ TEST(GpuHist, EvaluateSingleSplitMissing) { GPUTrainingParam param{tparam}; thrust::device_vector feature_set = std::vector{0}; - thrust::device_vector feature_segments = std::vector{0, 2}; + thrust::device_vector feature_segments = std::vector{0, 2}; thrust::device_vector feature_values = std::vector{1.0, 2.0}; thrust::device_vector feature_min_values = std::vector{0.0}; auto feature_histogram = ConvertToInteger(&ctx, {{-0.5, 0.5}, {0.5, 0.5}}); @@ -412,7 +412,7 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) { GPUTrainingParam param{tparam}; thrust::device_vector feature_set = std::vector{1}; - thrust::device_vector feature_segments = std::vector{0, 2, 4}; + thrust::device_vector feature_segments = std::vector{0, 2, 4}; thrust::device_vector feature_values = std::vector{1.0, 2.0, 11.0, 12.0}; thrust::device_vector feature_min_values = std::vector{0.0, 10.0}; auto feature_histogram = @@ -446,7 +446,7 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) { GPUTrainingParam param{tparam}; thrust::device_vector feature_set = std::vector{0, 1}; - thrust::device_vector feature_segments = std::vector{0, 2, 4}; + thrust::device_vector feature_segments = std::vector{0, 2, 4}; thrust::device_vector feature_values = std::vector{1.0, 2.0, 11.0, 12.0}; thrust::device_vector feature_min_values = std::vector{0.0, 10.0}; auto feature_histogram = @@ -478,7 +478,7 @@ TEST(GpuHist, EvaluateSplits) { GPUTrainingParam param{tparam}; thrust::device_vector feature_set = std::vector{0, 1}; - thrust::device_vector feature_segments = std::vector{0, 2, 4}; + thrust::device_vector feature_segments = std::vector{0, 2, 4}; thrust::device_vector feature_values = std::vector{1.0, 2.0, 11.0, 12.0}; thrust::device_vector feature_min_values = std::vector{0.0, 0.0}; auto feature_histogram_left = @@ -546,7 +546,7 @@ TEST_F(TestPartitionBasedSplit, GpuHist) { ASSERT_NEAR(split.loss_chg, best_score_, 1e-2); } -class MGPUHistTest : public BaseMGPUTest {}; +class MGPUHistTest : public collective::BaseMGPUTest {}; namespace { void VerifyColumnSplitEvaluateSingleSplit(bool is_categorical) { @@ -589,21 +589,29 @@ void VerifyColumnSplitEvaluateSingleSplit(bool is_categorical) { evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, true, ctx.Device()); DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(&ctx, input, shared_inputs).split; - EXPECT_EQ(result.findex, 1) << "rank: " << rank; + EXPECT_EQ(result.findex, 1); if (is_categorical) { ASSERT_TRUE(std::isnan(result.fvalue)); } else { - EXPECT_EQ(result.fvalue, 11.0) << "rank: " << rank; + EXPECT_EQ(result.fvalue, 11.0); } - EXPECT_EQ(result.left_sum + result.right_sum, parent_sum) << "rank: " << rank; + EXPECT_EQ(result.left_sum + result.right_sum, parent_sum); } } // anonymous namespace TEST_F(MGPUHistTest, ColumnSplitEvaluateSingleSplit) { - DoTest(VerifyColumnSplitEvaluateSingleSplit, false); + if (common::AllVisibleGPUs() > 1) { + // We can't emulate multiple GPUs with NCCL. + this->DoTest([] { VerifyColumnSplitEvaluateSingleSplit(false); }, false, true); + } + this->DoTest([] { VerifyColumnSplitEvaluateSingleSplit(false); }, true, true); } TEST_F(MGPUHistTest, ColumnSplitEvaluateSingleCategoricalSplit) { - DoTest(VerifyColumnSplitEvaluateSingleSplit, true); + if (common::AllVisibleGPUs() > 1) { + // We can't emulate multiple GPUs with NCCL. + this->DoTest([] { VerifyColumnSplitEvaluateSingleSplit(true); }, false, true); + } + this->DoTest([] { VerifyColumnSplitEvaluateSingleSplit(true); }, true, true); } } // namespace xgboost::tree diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu index f7f2e27ea1be..84cd956db094 100644 --- a/tests/cpp/tree/gpu_hist/test_histogram.cu +++ b/tests/cpp/tree/gpu_hist/test_histogram.cu @@ -239,4 +239,18 @@ void TestAtomicAdd() { TEST(Histogram, AtomicAddInt64) { TestAtomicAdd(); } + +TEST(Histogram, Quantiser) { + auto ctx = MakeCUDACtx(0); + std::size_t n_samples{16}; + HostDeviceVector gpair(n_samples, GradientPair{1.0, 1.0}); + gpair.SetDevice(ctx.Device()); + + auto quantiser = GradientQuantiser(&ctx, gpair.DeviceSpan(), MetaInfo()); + for (auto v : gpair.ConstHostVector()) { + auto gh = quantiser.ToFloatingPoint(quantiser.ToFixedPoint(v)); + ASSERT_EQ(gh.GetGrad(), 1.0); + ASSERT_EQ(gh.GetHess(), 1.0); + } +} } // namespace xgboost::tree diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc index 8eab1cd9d987..6f1086798bb4 100644 --- a/tests/cpp/tree/hist/test_evaluate_splits.cc +++ b/tests/cpp/tree/hist/test_evaluate_splits.cc @@ -22,6 +22,9 @@ #include "../../../../src/tree/hist/param.h" // for HistMakerTrainParam #include "../../../../src/tree/param.h" // for GradStats, TrainParam #include "../../helpers.h" // for RandomDataGenerator, AllThreadsFo... +#if defined(XGBOOST_USE_FEDERATED) +#include "../../plugin/federated/test_worker.h" // for TestFederatedGlobal +#endif // defined(XGBOOST_USE_FEDERATED) namespace xgboost::tree { void TestEvaluateSplits(bool force_read_by_column) { @@ -290,6 +293,7 @@ TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) { GradientPairPrecise{split.right_sum.GetGrad(), split.right_sum.GetHess()}); } +#if defined(XGBOOST_USE_FEDERATED) namespace { void DoTestEvaluateSplitsSecure(bool force_read_by_column) { Context ctx; @@ -364,9 +368,10 @@ void DoTestEvaluateSplitsSecure(bool force_read_by_column) { delete m; } -void TestEvaluateSplitsSecure (bool force_read_by_column) { +void TestEvaluateSplitsSecure(bool force_read_by_column) { auto constexpr kWorkers = 2; - RunWithInMemoryCommunicator(kWorkers, DoTestEvaluateSplitsSecure, force_read_by_column); + collective::TestFederatedGlobal(kWorkers, + [&] { DoTestEvaluateSplitsSecure(force_read_by_column); }); } } // anonymous namespace @@ -374,5 +379,5 @@ TEST(HistEvaluator, SecureEvaluate) { TestEvaluateSplitsSecure(false); TestEvaluateSplitsSecure(true); } - +#endif // defined(XGBOOST_USE_FEDERATED) } // namespace xgboost::tree diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc index 76428d1d83b4..fdf25d8021d2 100644 --- a/tests/cpp/tree/hist/test_histogram.cc +++ b/tests/cpp/tree/hist/test_histogram.cc @@ -14,7 +14,6 @@ #include // for max #include // for size_t #include // for int32_t, uint32_t -#include // for function #include // for back_inserter #include // for numeric_limits #include // for shared_ptr, allocator, unique_ptr @@ -33,6 +32,7 @@ #include "../../../../src/tree/hist/histogram.h" // for HistogramBuilder #include "../../../../src/tree/hist/param.h" // for HistMakerTrainParam #include "../../categorical_helpers.h" // for OneHotEncodeFeature +#include "../../collective/test_worker.h" // for TestDistributedGlobal #include "../../helpers.h" // for RandomDataGenerator, GenerateRa... namespace xgboost::tree { @@ -290,8 +290,10 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_ double hess = sol.GetHess(); if (is_distributed && (!is_col_split || (is_secure && is_col_split))) { // the solution also needs to be allreduce - collective::Allreduce(&grad, 1); - collective::Allreduce(&hess, 1); + collective::SafeColl( + collective::Allreduce(&ctx, linalg::MakeVec(&grad, 1), collective::Op::kSum)); + collective::SafeColl( + collective::Allreduce(&ctx, linalg::MakeVec(&hess, 1), collective::Op::kSum)); } ASSERT_NEAR(grad, histogram.Histogram()[nid][i].GetGrad(), kEps); ASSERT_NEAR(hess, histogram.Histogram()[nid][i].GetHess(), kEps); @@ -307,20 +309,21 @@ TEST(CPUHistogram, BuildHist) { TEST(CPUHistogram, BuildHistDist) { auto constexpr kWorkers = 4; - RunWithInMemoryCommunicator(kWorkers, TestBuildHistogram, true, false, false, false); - RunWithInMemoryCommunicator(kWorkers, TestBuildHistogram, true, true, false, false); + collective::TestDistributedGlobal(kWorkers, + [] { TestBuildHistogram(true, false, false, false); }); + collective::TestDistributedGlobal(kWorkers, [] { TestBuildHistogram(true, true, false, false); }); } TEST(CPUHistogram, BuildHistDistColSplit) { auto constexpr kWorkers = 4; - RunWithInMemoryCommunicator(kWorkers, TestBuildHistogram, true, true, true, false); - RunWithInMemoryCommunicator(kWorkers, TestBuildHistogram, true, false, true, false); + collective::TestDistributedGlobal(kWorkers, [] { TestBuildHistogram(true, false, true, false); }); + collective::TestDistributedGlobal(kWorkers, [] { TestBuildHistogram(true, true, true, false); }); } TEST(CPUHistogram, BuildHistDistColSplitSecure) { auto constexpr kWorkers = 4; - RunWithInMemoryCommunicator(kWorkers, TestBuildHistogram, true, true, true, true); - RunWithInMemoryCommunicator(kWorkers, TestBuildHistogram, true, false, true, true); + collective::TestDistributedGlobal(kWorkers, [] { TestBuildHistogram(true, true, true, true); }); + collective::TestDistributedGlobal(kWorkers, [] { TestBuildHistogram(true, false, true, true); }); } namespace { @@ -428,9 +431,9 @@ void TestHistogramExternalMemory(Context const *ctx, BatchParam batch_param, boo batch_param.hess = hess; } - std::vector partition_size(1, 0); + std::vector partition_size(1, 0); bst_bin_t total_bins{0}; - bst_row_t n_samples{0}; + bst_idx_t n_samples{0}; auto gpair = GenerateRandomGradients(m->Info().num_row_, 0.0, 1.0); auto const &h_gpair = gpair.HostVector(); diff --git a/tests/cpp/tree/test_approx.cc b/tests/cpp/tree/test_approx.cc index 38da629b1438..b2949e5952a2 100644 --- a/tests/cpp/tree/test_approx.cc +++ b/tests/cpp/tree/test_approx.cc @@ -1,15 +1,15 @@ /** - * Copyright 2021-2023 by XGBoost contributors. + * Copyright 2021-2024, XGBoost contributors. */ #include #include "../../../src/common/numeric.h" #include "../../../src/tree/common_row_partitioner.h" +#include "../collective/test_worker.h" // for TestDistributedGlobal #include "../helpers.h" #include "test_partitioner.h" -namespace xgboost { -namespace tree { +namespace xgboost::tree { namespace { std::vector GenerateHess(size_t n_samples) { auto grad = GenerateRandomGradients(n_samples); @@ -145,8 +145,9 @@ TEST(Approx, PartitionerColSplit) { } auto constexpr kWorkers = 4; - RunWithInMemoryCommunicator(kWorkers, TestColumnSplitPartitioner, n_samples, base_rowid, Xy, - &hess, min_value, mid_value, mid_partitioner); + collective::TestDistributedGlobal(kWorkers, [&] { + TestColumnSplitPartitioner(n_samples, base_rowid, Xy, &hess, min_value, mid_value, + mid_partitioner); + }); } -} // namespace tree -} // namespace xgboost +} // namespace xgboost::tree diff --git a/tests/cpp/tree/test_evaluate_splits.h b/tests/cpp/tree/test_evaluate_splits.h index 6506b54e88f0..a25e75aef4a9 100644 --- a/tests/cpp/tree/test_evaluate_splits.h +++ b/tests/cpp/tree/test_evaluate_splits.h @@ -1,5 +1,5 @@ /** - * Copyright 2022-2023 by XGBoost Contributors + * Copyright 2022-2024, XGBoost Contributors */ #include #include // for GradientPairInternal, GradientPairPrecise @@ -14,7 +14,6 @@ #include // for numeric_limits #include // for iota #include // for make_tuple, tie, tuple -#include // for pair #include // for vector #include "../../../src/common/hist_util.h" // for HistogramCuts, HistCollection, GHistRow @@ -23,7 +22,6 @@ #include "../../../src/tree/param.h" // for TrainParam, GradStats #include "../../../src/tree/split_evaluator.h" // for TreeEvaluator #include "../helpers.h" // for SimpleLCG, SimpleRealUniformDistribution -#include "gtest/gtest_pred_impl.h" // for AssertionResult, ASSERT_EQ, ASSERT_TRUE namespace xgboost::tree { /** @@ -96,13 +94,11 @@ class TestPartitionBasedSplit : public ::testing::Test { // enumerate all possible partitions to find the optimal split do { - int32_t thresh; - float score; std::vector sorted_hist(node_hist.size()); for (size_t i = 0; i < sorted_hist.size(); ++i) { sorted_hist[i] = node_hist[sorted_idx_[i]]; } - std::tie(thresh, score) = enumerate({sorted_hist}, total_gpair_); + auto [thresh, score] = enumerate({sorted_hist}, total_gpair_); if (score > best_score_) { best_score_ = score; } diff --git a/tests/cpp/tree/test_fit_stump.cc b/tests/cpp/tree/test_fit_stump.cc index d9441fd6f3fd..720d878521ec 100644 --- a/tests/cpp/tree/test_fit_stump.cc +++ b/tests/cpp/tree/test_fit_stump.cc @@ -1,11 +1,12 @@ /** - * Copyright 2022-2023, XGBoost Contributors + * Copyright 2022-2024, XGBoost Contributors */ #include #include #include "../../src/common/linalg_op.h" #include "../../src/tree/fit_stump.h" +#include "../collective/test_worker.h" // for TestDistributedGlobal #include "../helpers.h" namespace xgboost::tree { @@ -43,7 +44,7 @@ TEST(InitEstimation, FitStump) { #if defined(XGBOOST_USE_CUDA) TEST(InitEstimation, GPUFitStump) { Context ctx; - ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}}); + ctx.UpdateAllowUnknown(Args{{"device", "cuda"}}); TestFitStump(&ctx); } #endif // defined(XGBOOST_USE_CUDA) @@ -51,6 +52,6 @@ TEST(InitEstimation, GPUFitStump) { TEST(InitEstimation, FitStumpColumnSplit) { Context ctx; auto constexpr kWorldSize{3}; - RunWithInMemoryCommunicator(kWorldSize, &TestFitStump, &ctx, DataSplitMode::kCol); + collective::TestDistributedGlobal(kWorldSize, [&] { TestFitStump(&ctx, DataSplitMode::kCol); }); } } // namespace xgboost::tree diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 6f937351ea23..c3a949008261 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -13,14 +13,19 @@ #include "../../../src/common/common.h" #include "../../../src/data/ellpack_page.cuh" // for EllpackPageImpl #include "../../../src/data/ellpack_page.h" // for EllpackPage -#include "../../../src/tree/param.h" // for TrainParam +#include "../../../src/tree/param.h" // for TrainParam #include "../../../src/tree/updater_gpu_hist.cu" -#include "../filesystem.h" // dmlc::TemporaryDirectory +#include "../collective/test_worker.h" // for BaseMGPUTest +#include "../filesystem.h" // dmlc::TemporaryDirectory #include "../helpers.h" #include "../histogram_helpers.h" #include "xgboost/context.h" #include "xgboost/json.h" +#if defined(XGBOOST_USE_FEDERATED) +#include "../plugin/federated/test_worker.h" // for TestFederatedGlobal +#endif // defined(XGBOOST_USE_FEDERATED) + namespace xgboost::tree { TEST(GpuHist, DeviceHistogram) { // Ensures that node allocates correctly after reaching `kStopGrowingSize`. @@ -440,7 +445,7 @@ RegTree GetHistTree(Context const* ctx, DMatrix* dmat) { return tree; } -void VerifyHistColumnSplit(bst_row_t rows, bst_feature_t cols, RegTree const& expected_tree) { +void VerifyHistColumnSplit(bst_idx_t rows, bst_feature_t cols, RegTree const& expected_tree) { Context ctx(MakeCUDACtx(GPUIDX)); auto Xy = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true); @@ -458,9 +463,9 @@ void VerifyHistColumnSplit(bst_row_t rows, bst_feature_t cols, RegTree const& ex } } // anonymous namespace -class MGPUHistTest : public BaseMGPUTest {}; +class MGPUHistTest : public collective::BaseMGPUTest {}; -TEST_F(MGPUHistTest, GPUHistColumnSplit) { +TEST_F(MGPUHistTest, HistColumnSplit) { auto constexpr kRows = 32; auto constexpr kCols = 16; @@ -468,7 +473,8 @@ TEST_F(MGPUHistTest, GPUHistColumnSplit) { auto dmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true); RegTree expected_tree = GetHistTree(&ctx, dmat.get()); - DoTest(VerifyHistColumnSplit, kRows, kCols, expected_tree); + this->DoTest([&] { VerifyHistColumnSplit(kRows, kCols, expected_tree); }, true); + this->DoTest([&] { VerifyHistColumnSplit(kRows, kCols, expected_tree); }, false); } namespace { @@ -490,7 +496,7 @@ RegTree GetApproxTree(Context const* ctx, DMatrix* dmat) { return tree; } -void VerifyApproxColumnSplit(bst_row_t rows, bst_feature_t cols, RegTree const& expected_tree) { +void VerifyApproxColumnSplit(bst_idx_t rows, bst_feature_t cols, RegTree const& expected_tree) { Context ctx(MakeCUDACtx(GPUIDX)); auto Xy = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true); @@ -508,7 +514,7 @@ void VerifyApproxColumnSplit(bst_row_t rows, bst_feature_t cols, RegTree const& } } // anonymous namespace -class MGPUApproxTest : public BaseMGPUTest {}; +class MGPUApproxTest : public collective::BaseMGPUTest {}; TEST_F(MGPUApproxTest, GPUApproxColumnSplit) { auto constexpr kRows = 32; @@ -518,6 +524,7 @@ TEST_F(MGPUApproxTest, GPUApproxColumnSplit) { auto dmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true); RegTree expected_tree = GetApproxTree(&ctx, dmat.get()); - DoTest(VerifyApproxColumnSplit, kRows, kCols, expected_tree); + this->DoTest([&] { VerifyApproxColumnSplit(kRows, kCols, expected_tree); }, true); + this->DoTest([&] { VerifyApproxColumnSplit(kRows, kCols, expected_tree); }, false); } } // namespace xgboost::tree diff --git a/tests/cpp/tree/test_histmaker.cc b/tests/cpp/tree/test_histmaker.cc index 963660f59eda..b8b9e46cac18 100644 --- a/tests/cpp/tree/test_histmaker.cc +++ b/tests/cpp/tree/test_histmaker.cc @@ -5,7 +5,8 @@ #include #include -#include "../../../src/tree/param.h" // for TrainParam +#include "../../../src/tree/param.h" // for TrainParam +#include "../collective/test_worker.h" // for TestDistributedGlobal #include "../helpers.h" namespace xgboost::tree { @@ -118,8 +119,8 @@ void TestColumnSplit(bool categorical) { } auto constexpr kWorldSize = 2; - RunWithInMemoryCommunicator(kWorldSize, VerifyColumnSplit, kRows, kCols, categorical, - std::cref(expected_tree)); + collective::TestDistributedGlobal( + kWorldSize, [&] { VerifyColumnSplit(kRows, kCols, categorical, expected_tree); }); } } // anonymous namespace diff --git a/tests/cpp/tree/test_multi_target_tree_model.cc b/tests/cpp/tree/test_multi_target_tree_model.cc index 550b8837c1cd..39e4cb4b52f0 100644 --- a/tests/cpp/tree/test_multi_target_tree_model.cc +++ b/tests/cpp/tree/test_multi_target_tree_model.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023 by XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include #include // for Context @@ -7,23 +7,30 @@ #include // for RegTree namespace xgboost { -TEST(MultiTargetTree, JsonIO) { +namespace { +auto MakeTreeForTest() { bst_target_t n_targets{3}; bst_feature_t n_features{4}; - RegTree tree{n_targets, n_features}; - ASSERT_TRUE(tree.IsMultiTarget()); + std::unique_ptr tree{std::make_unique(n_targets, n_features)}; + CHECK(tree->IsMultiTarget()); linalg::Vector base_weight{{1.0f, 2.0f, 3.0f}, {3ul}, DeviceOrd::CPU()}; linalg::Vector left_weight{{2.0f, 3.0f, 4.0f}, {3ul}, DeviceOrd::CPU()}; linalg::Vector right_weight{{3.0f, 4.0f, 5.0f}, {3ul}, DeviceOrd::CPU()}; - tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, base_weight.HostView(), - left_weight.HostView(), right_weight.HostView()); - ASSERT_EQ(tree.NumNodes(), 3); - ASSERT_EQ(tree.NumTargets(), 3); - ASSERT_EQ(tree.GetMultiTargetTree()->Size(), 3); - ASSERT_EQ(tree.Size(), 3); + tree->ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, base_weight.HostView(), + left_weight.HostView(), right_weight.HostView()); + return tree; +} +} // namespace + +TEST(MultiTargetTree, JsonIO) { + auto tree = MakeTreeForTest(); + ASSERT_EQ(tree->NumNodes(), 3); + ASSERT_EQ(tree->NumTargets(), 3); + ASSERT_EQ(tree->GetMultiTargetTree()->Size(), 3); + ASSERT_EQ(tree->Size(), 3); Json jtree{Object{}}; - tree.SaveModel(&jtree); + tree->SaveModel(&jtree); auto check_jtree = [](Json jtree, RegTree const& tree) { ASSERT_EQ(get(jtree["tree_param"]["num_nodes"]), std::to_string(tree.NumNodes())); @@ -33,7 +40,7 @@ TEST(MultiTargetTree, JsonIO) { ASSERT_EQ(get(jtree["left_children"]).size(), tree.NumNodes()); ASSERT_EQ(get(jtree["right_children"]).size(), tree.NumNodes()); }; - check_jtree(jtree, tree); + check_jtree(jtree, *tree); RegTree loaded; loaded.LoadModel(jtree); @@ -42,6 +49,30 @@ TEST(MultiTargetTree, JsonIO) { Json jtree1{Object{}}; loaded.SaveModel(&jtree1); - check_jtree(jtree1, tree); + check_jtree(jtree1, *tree); +} + +TEST(MultiTargetTree, DumpDot) { + auto tree = MakeTreeForTest(); + auto n_features = tree->NumFeatures(); + FeatureMap fmap; + for (bst_feature_t f = 0; f < n_features; ++f) { + auto name = "feat_" + std::to_string(f); + fmap.PushBack(f, name.c_str(), "q"); + } + auto str = tree->DumpModel(fmap, true, "dot"); + ASSERT_NE(str.find("leaf=[2, 3, 4]"), std::string::npos); + ASSERT_NE(str.find("leaf=[3, 4, 5]"), std::string::npos); + + { + bst_target_t n_targets{4}; + bst_feature_t n_features{4}; + RegTree tree{n_targets, n_features}; + linalg::Vector weight{{1.0f, 2.0f, 3.0f, 4.0f}, {4ul}, DeviceOrd::CPU()}; + tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, weight.HostView(), + weight.HostView(), weight.HostView()); + auto str = tree.DumpModel(fmap, true, "dot"); + ASSERT_NE(str.find("leaf=[1, 2, ..., 4]"), std::string::npos); + } } } // namespace xgboost diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index 4021c9959440..ce637caa4d46 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -13,6 +13,7 @@ #include "../../../src/tree/common_row_partitioner.h" #include "../../../src/tree/hist/expand_entry.h" // for MultiExpandEntry, CPUExpandEntry #include "../../../src/tree/param.h" +#include "../collective/test_worker.h" // for TestDistributedGlobal #include "../helpers.h" #include "test_partitioner.h" #include "xgboost/data.h" @@ -190,9 +191,10 @@ void TestColumnSplitPartitioner(bst_target_t n_targets) { } auto constexpr kWorkers = 4; - RunWithInMemoryCommunicator(kWorkers, VerifyColumnSplitPartitioner, n_targets, - n_samples, n_features, base_rowid, Xy, min_value, mid_value, - mid_partitioner); + collective::TestDistributedGlobal(kWorkers, [&] { + VerifyColumnSplitPartitioner(n_targets, n_samples, n_features, base_rowid, Xy, + min_value, mid_value, mid_partitioner); + }); } } // anonymous namespace @@ -201,7 +203,7 @@ TEST(QuantileHist, PartitionerColSplit) { TestColumnSplitPartitioner(3); } namespace { -void VerifyColumnSplit(Context const* ctx, bst_row_t rows, bst_feature_t cols, bst_target_t n_targets, +void VerifyColumnSplit(Context const* ctx, bst_idx_t rows, bst_feature_t cols, bst_target_t n_targets, RegTree const& expected_tree) { auto Xy = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true); linalg::Matrix gpair = GenerateRandomGradients(ctx, rows, n_targets); @@ -245,8 +247,9 @@ void TestColumnSplit(bst_target_t n_targets) { } auto constexpr kWorldSize = 2; - RunWithInMemoryCommunicator(kWorldSize, VerifyColumnSplit, &ctx, kRows, kCols, n_targets, - std::cref(expected_tree)); + collective::TestDistributedGlobal(kWorldSize, [&] { + VerifyColumnSplit(&ctx, kRows, kCols, n_targets, std::cref(expected_tree)); + }); } } // anonymous namespace diff --git a/tests/cpp/tree/test_refresh.cc b/tests/cpp/tree/test_refresh.cc index c8859c898519..bbd274a08d0f 100644 --- a/tests/cpp/tree/test_refresh.cc +++ b/tests/cpp/tree/test_refresh.cc @@ -15,7 +15,7 @@ namespace xgboost::tree { TEST(Updater, Refresh) { - bst_row_t constexpr kRows = 8; + bst_idx_t constexpr kRows = 8; bst_feature_t constexpr kCols = 16; Context ctx; diff --git a/tests/python-gpu/test_from_cudf.py b/tests/python-gpu/test_from_cudf.py index 8707af0c885f..fd7c9d745db0 100644 --- a/tests/python-gpu/test_from_cudf.py +++ b/tests/python-gpu/test_from_cudf.py @@ -10,7 +10,7 @@ cudf = pytest.importorskip("cudf") -def dmatrix_from_cudf(input_type, DMatrixT, missing=np.NAN): +def dmatrix_from_cudf(input_type, DMatrixT, missing=np.nan): """Test constructing DMatrix from cudf""" import pandas as pd @@ -38,8 +38,8 @@ def dmatrix_from_cudf(input_type, DMatrixT, missing=np.NAN): def _test_from_cudf(DMatrixT): """Test constructing DMatrix from cudf""" - dmatrix_from_cudf(np.float32, DMatrixT, np.NAN) - dmatrix_from_cudf(np.float64, DMatrixT, np.NAN) + dmatrix_from_cudf(np.float32, DMatrixT, np.nan) + dmatrix_from_cudf(np.float64, DMatrixT, np.nan) dmatrix_from_cudf(np.int8, DMatrixT, 2) dmatrix_from_cudf(np.int32, DMatrixT, -2) @@ -66,20 +66,11 @@ def _test_from_cudf(DMatrixT): ) # Test when number of elements is less than 8 - X = cudf.DataFrame({"x": cudf.Series([0, 1, 2, np.NAN, 4], dtype=np.int32)}) + X = cudf.DataFrame({"x": cudf.Series([0, 1, 2, np.nan, 4], dtype=np.int32)}) dtrain = DMatrixT(X) assert dtrain.num_col() == 1 assert dtrain.num_row() == 5 - # Boolean is not supported. - X_boolean = cudf.DataFrame({"x": cudf.Series([True, False])}) - with pytest.raises(Exception): - dtrain = DMatrixT(X_boolean) - - y_boolean = cudf.DataFrame({"x": cudf.Series([True, False, True, True, True])}) - with pytest.raises(Exception): - dtrain = DMatrixT(X_boolean, label=y_boolean) - def _test_cudf_training(DMatrixT): import pandas as pd @@ -234,7 +225,7 @@ def test_cudf_categorical(self) -> None: assert len(interfaces) == X.shape[1] # test missing value - X = cudf.DataFrame({"f0": ["a", "b", np.NaN]}) + X = cudf.DataFrame({"f0": ["a", "b", np.nan]}) X["f0"] = X["f0"].astype("category") df, cat_codes, _, _ = xgb.data._transform_cudf_df( X, None, None, enable_categorical=True diff --git a/tests/python-gpu/test_from_cupy.py b/tests/python-gpu/test_from_cupy.py index 85d54c78dbff..7deb9aac08ba 100644 --- a/tests/python-gpu/test_from_cupy.py +++ b/tests/python-gpu/test_from_cupy.py @@ -18,7 +18,7 @@ def test_array_interface() -> None: np.testing.assert_equal(cp.asnumpy(arr), cp.asnumpy(ret)) -def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN): +def dmatrix_from_cupy(input_type, DMatrixT, missing=np.nan): """Test constructing DMatrix from cupy""" kRows = 80 kCols = 3 @@ -46,9 +46,9 @@ def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN): def _test_from_cupy(DMatrixT): """Test constructing DMatrix from cupy""" - dmatrix_from_cupy(np.float16, DMatrixT, np.NAN) - dmatrix_from_cupy(np.float32, DMatrixT, np.NAN) - dmatrix_from_cupy(np.float64, DMatrixT, np.NAN) + dmatrix_from_cupy(np.float16, DMatrixT, np.nan) + dmatrix_from_cupy(np.float32, DMatrixT, np.nan) + dmatrix_from_cupy(np.float64, DMatrixT, np.nan) dmatrix_from_cupy(np.uint8, DMatrixT, 2) dmatrix_from_cupy(np.uint32, DMatrixT, 3) diff --git a/tests/python-gpu/test_gpu_ranking.py b/tests/python-gpu/test_gpu_ranking.py index 2579b17ded3e..b7c5c3adb08d 100644 --- a/tests/python-gpu/test_gpu_ranking.py +++ b/tests/python-gpu/test_gpu_ranking.py @@ -6,6 +6,7 @@ import xgboost from xgboost import testing as tm +from xgboost.testing.ranking import run_normalization pytestmark = tm.timeout(30) @@ -126,3 +127,7 @@ def test_with_mq2008(objective, metric) -> None: dtest = xgboost.DMatrix(x_test, y_test, qid=qid_test) comp_training_with_rank_objective(dtrain, dtest, objective, metric) + + +def test_normalization() -> None: + run_normalization("cuda") diff --git a/tests/python/test_collective.py b/tests/python/test_collective.py index f7de0400d21f..a3923e9df4e4 100644 --- a/tests/python/test_collective.py +++ b/tests/python/test_collective.py @@ -1,16 +1,14 @@ import multiprocessing import socket import sys -import time +from threading import Thread import numpy as np import pytest import xgboost as xgb from xgboost import RabitTracker, build_info, federated - -if sys.platform.startswith("win"): - pytest.skip("Skipping collective tests on Windows", allow_module_level=True) +from xgboost import testing as tm def run_rabit_worker(rabit_env, world_size): @@ -18,20 +16,21 @@ def run_rabit_worker(rabit_env, world_size): assert xgb.collective.get_world_size() == world_size assert xgb.collective.is_distributed() assert xgb.collective.get_processor_name() == socket.gethostname() - ret = xgb.collective.broadcast('test1234', 0) - assert str(ret) == 'test1234' + ret = xgb.collective.broadcast("test1234", 0) + assert str(ret) == "test1234" ret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM) assert np.array_equal(ret, np.asarray([2, 4, 6])) -def test_rabit_communicator(): +def test_rabit_communicator() -> None: world_size = 2 - tracker = RabitTracker(host_ip='127.0.0.1', n_workers=world_size) - tracker.start(world_size) + tracker = RabitTracker(host_ip="127.0.0.1", n_workers=world_size) + tracker.start() workers = [] for _ in range(world_size): - worker = multiprocessing.Process(target=run_rabit_worker, - args=(tracker.worker_envs(), world_size)) + worker = multiprocessing.Process( + target=run_rabit_worker, args=(tracker.worker_args(), world_size) + ) workers.append(worker) worker.start() for worker in workers: @@ -39,39 +38,44 @@ def test_rabit_communicator(): assert worker.exitcode == 0 -def run_federated_worker(port, world_size, rank): - with xgb.collective.CommunicatorContext(xgboost_communicator='federated', - federated_server_address=f'localhost:{port}', - federated_world_size=world_size, - federated_rank=rank): +def run_federated_worker(port: int, world_size: int, rank: int) -> None: + with xgb.collective.CommunicatorContext( + dmlc_communicator="federated", + federated_server_address=f"localhost:{port}", + federated_world_size=world_size, + federated_rank=rank, + ): assert xgb.collective.get_world_size() == world_size assert xgb.collective.is_distributed() - assert xgb.collective.get_processor_name() == f'rank{rank}' - ret = xgb.collective.broadcast('test1234', 0) - assert str(ret) == 'test1234' - ret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM) - assert np.array_equal(ret, np.asarray([2, 4, 6])) + assert xgb.collective.get_processor_name() == f"rank:{rank}" + bret = xgb.collective.broadcast("test1234", 0) + assert str(bret) == "test1234" + aret = xgb.collective.allreduce(np.asarray([1, 2, 3]), xgb.collective.Op.SUM) + assert np.array_equal(aret, np.asarray([2, 4, 6])) +@pytest.mark.skipif(**tm.skip_win()) def test_federated_communicator(): if not build_info()["USE_FEDERATED"]: pytest.skip("XGBoost not built with federated learning enabled") port = 9091 world_size = 2 - server = multiprocessing.Process(target=xgb.federated.run_federated_server, args=(port, world_size)) - server.start() - time.sleep(1) - if not server.is_alive(): + tracker = multiprocessing.Process( + target=federated.run_federated_server, + kwargs={"port": port, "n_workers": world_size}, + ) + tracker.start() + if not tracker.is_alive(): raise Exception("Error starting Federated Learning server") workers = [] for rank in range(world_size): - worker = multiprocessing.Process(target=run_federated_worker, - args=(port, world_size, rank)) + worker = multiprocessing.Process( + target=run_federated_worker, args=(port, world_size, rank) + ) workers.append(worker) worker.start() for worker in workers: worker.join() assert worker.exitcode == 0 - server.terminate() diff --git a/tests/python/test_data_iterator.py b/tests/python/test_data_iterator.py index 174f5606cc27..7f0153565c4b 100644 --- a/tests/python/test_data_iterator.py +++ b/tests/python/test_data_iterator.py @@ -160,9 +160,11 @@ def test_data_iterator( class IterForCacheTest(xgb.DataIter): - def __init__(self, x: np.ndarray, y: np.ndarray, w: np.ndarray) -> None: + def __init__( + self, x: np.ndarray, y: np.ndarray, w: np.ndarray, release_data: bool + ) -> None: self.kwargs = {"data": x, "label": y, "weight": w} - super().__init__(release_data=False) + super().__init__(release_data=release_data) def next(self, input_data: Callable) -> int: if self.it == 1: @@ -181,7 +183,9 @@ def test_data_cache() -> None: n_samples_per_batch = 16 data = make_batches(n_samples_per_batch, n_features, n_batches, False) batches = [v[0] for v in data] - it = IterForCacheTest(*batches) + + # Test with a cache. + it = IterForCacheTest(batches[0], batches[1], batches[2], release_data=False) transform = xgb.data._proxy_transform called = 0 @@ -196,6 +200,12 @@ def mock(*args: Any, **kwargs: Any) -> Any: assert it._data_ref is weakref.ref(batches[0]) assert called == 1 + # Test without a cache. + called = 0 + it = IterForCacheTest(batches[0], batches[1], batches[2], release_data=True) + xgb.QuantileDMatrix(it) + assert called == 4 + xgb.data._proxy_transform = transform diff --git a/tests/python/test_dmatrix.py b/tests/python/test_dmatrix.py index 9d123ddb9dbc..fe60d9bfde50 100644 --- a/tests/python/test_dmatrix.py +++ b/tests/python/test_dmatrix.py @@ -147,7 +147,7 @@ def test_feature_names_slice(self): assert dm.slice([0, 1]).num_col() == dm.num_col() assert dm.slice([0, 1]).feature_names == dm.feature_names - with pytest.raises(ValueError, match=r"Duplicates found: \['bar'\]"): + with pytest.raises(ValueError, match=r"Duplicates found: \[.*'bar'.*\]"): dm.feature_names = ["bar"] * (data.shape[1] - 2) + ["a", "b"] dm.feature_types = list("qiqiq") @@ -264,7 +264,7 @@ def test_sparse_dmatrix_csr(self): assert (dtrain.num_row(), dtrain.num_col()) == (nrow, ncol) watchlist = [(dtrain, "train")] param = {"max_depth": 3, "objective": "binary:logistic"} - bst = xgb.train(param, dtrain, 5, watchlist) + bst = xgb.train(param, dtrain, 5, evals=watchlist) bst.predict(dtrain) i32 = csr_matrix((x.data.astype(np.int32), x.indices, x.indptr), shape=x.shape) @@ -302,7 +302,7 @@ def test_sparse_dmatrix_csc(self): assert (dtrain.num_row(), dtrain.num_col()) == (nrow, ncol) watchlist = [(dtrain, "train")] param = {"max_depth": 3, "objective": "binary:logistic"} - bst = xgb.train(param, dtrain, 5, watchlist) + bst = xgb.train(param, dtrain, 5, evals=watchlist) bst.predict(dtrain) def test_unknown_data(self): @@ -320,9 +320,10 @@ class Data: X = rng.rand(10, 10) y = rng.rand(10) X = sparse.dok_matrix(X) - Xy = xgb.DMatrix(X, y) - assert Xy.num_row() == 10 - assert Xy.num_col() == 10 + with pytest.warns(UserWarning, match="dok_matrix"): + Xy = xgb.DMatrix(X, y) + assert Xy.num_row() == 10 + assert Xy.num_col() == 10 @pytest.mark.skipif(**tm.no_pandas()) def test_np_categorical(self): @@ -343,8 +344,8 @@ def test_scipy_categorical(self): X = X.values.astype(np.float32) feature_types = ["c"] * n_features - X[1, 3] = np.NAN - X[2, 4] = np.NAN + X[1, 3] = np.nan + X[2, 4] = np.nan X = sparse.csr_matrix(X) Xy = xgb.DMatrix(X, y, feature_types=feature_types) diff --git a/tests/python/test_predict.py b/tests/python/test_predict.py index 6b59ef540d7b..4a81e807bfa3 100644 --- a/tests/python/test_predict.py +++ b/tests/python/test_predict.py @@ -241,7 +241,7 @@ def test_dtypes(self) -> None: # unsupported types for dtype in [ - np.string_, + np.bytes_, np.complex64, np.complex128, ]: diff --git a/tests/python/test_quantile_dmatrix.py b/tests/python/test_quantile_dmatrix.py index 28a7eb37a17b..2d9c15c8502f 100644 --- a/tests/python/test_quantile_dmatrix.py +++ b/tests/python/test_quantile_dmatrix.py @@ -333,7 +333,7 @@ def test_dtypes(self) -> None: # unsupported types for dtype in [ - np.string_, + np.bytes_, np.complex64, np.complex128, ]: diff --git a/tests/python/test_ranking.py b/tests/python/test_ranking.py index 8bdeb070ffbe..49508f594c52 100644 --- a/tests/python/test_ranking.py +++ b/tests/python/test_ranking.py @@ -13,6 +13,7 @@ from xgboost import testing as tm from xgboost.testing.data import RelDataCV, simulate_clicks, sort_ltr_samples from xgboost.testing.params import lambdarank_parameter_strategy +from xgboost.testing.ranking import run_normalization def test_ndcg_custom_gain(): @@ -53,6 +54,20 @@ def ndcg_gain(y: np.ndarray) -> np.ndarray: assert byxgb.evals_result() == bynp.evals_result() assert byxgb_json == bynp_json + # test pairwise can handle max_rel > 31, while ndcg metric is using custom gain + X, y, q, w = tm.make_ltr(n_samples=1024, n_features=4, n_query_groups=3, max_rel=33) + ranknet = xgboost.XGBRanker( + tree_method="hist", + ndcg_exp_gain=False, + n_estimators=10, + objective="rank:pairwise", + ) + ranknet.fit(X, y, qid=q, eval_set=[(X, y)], eval_qid=[q]) + history = ranknet.evals_result() + assert ( + history["validation_0"]["ndcg@32"][0] < history["validation_0"]["ndcg@32"][-1] + ) + def test_ranking_with_unweighted_data(): Xrow = np.array([1, 2, 6, 8, 11, 14, 16, 17]) @@ -188,6 +203,10 @@ def after_training(self, model) -> bool: assert df["ti+"].iloc[-1] < df["ti+"].iloc[0] +def test_normalization() -> None: + run_normalization("cpu") + + class TestRanking: @classmethod def setup_class(cls): diff --git a/tests/python/test_tracker.py b/tests/python/test_tracker.py index 1f42711a20d9..5d508f0d17ff 100644 --- a/tests/python/test_tracker.py +++ b/tests/python/test_tracker.py @@ -3,33 +3,33 @@ import numpy as np import pytest +from hypothesis import HealthCheck, given, settings, strategies import xgboost as xgb from xgboost import RabitTracker, collective from xgboost import testing as tm -if sys.platform.startswith("win"): - pytest.skip("Skipping dask tests on Windows", allow_module_level=True) - def test_rabit_tracker(): tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1) - tracker.start(1) - with xgb.collective.CommunicatorContext(**tracker.worker_envs()): + tracker.start() + with xgb.collective.CommunicatorContext(**tracker.worker_args()): ret = xgb.collective.broadcast("test1234", 0) assert str(ret) == "test1234" @pytest.mark.skipif(**tm.not_linux()) def test_socket_error(): - tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1) - tracker.start(1) - env = tracker.worker_envs() - env["DMLC_TRACKER_PORT"] = 0 - env["DMLC_WORKER_CONNECT_RETRY"] = 1 - with pytest.raises(ValueError, match="127.0.0.1:0\n.*refused"): + tracker = RabitTracker(host_ip="127.0.0.1", n_workers=2) + tracker.start() + env = tracker.worker_args() + env["dmlc_tracker_port"] = 0 + env["dmlc_retry"] = 1 + with pytest.raises(ValueError, match="Failed to bootstrap the communication."): with xgb.collective.CommunicatorContext(**env): pass + with pytest.raises(ValueError): + tracker.free() def run_rabit_ops(client, n_workers): @@ -70,6 +70,40 @@ def test_rabit_ops(): run_rabit_ops(client, n_workers) +def run_allreduce(client) -> None: + from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args + + workers = tm.get_client_workers(client) + rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client) + n_workers = len(workers) + + def local_test(worker_id: int) -> None: + x = np.full(shape=(1024 * 1024 * 32), fill_value=1.0) + with CommunicatorContext(**rabit_args): + k = np.asarray([1.0]) + for i in range(128): + m = collective.allreduce(k, collective.Op.SUM) + assert m == n_workers + + y = collective.allreduce(x, collective.Op.SUM) + np.testing.assert_allclose(y, np.full_like(y, fill_value=float(n_workers))) + + futures = client.map(local_test, range(len(workers)), workers=workers) + results = client.gather(futures) + + +@pytest.mark.skipif(**tm.no_dask()) +def test_allreduce() -> None: + from distributed import Client, LocalCluster + + n_workers = 4 + for i in range(2): + with LocalCluster(n_workers=n_workers) as cluster: + with Client(cluster) as client: + for i in range(2): + run_allreduce(client) + + def run_broadcast(client): from xgboost.dask import _get_dask_config, _get_rabit_args @@ -109,6 +143,7 @@ def test_rabit_ops_ipv6(): run_rabit_ops(client, n_workers) +@pytest.mark.skipif(**tm.no_dask()) def test_rank_assignment() -> None: from distributed import Client, LocalCluster @@ -133,3 +168,107 @@ def local_test(worker_id): futures = client.map(local_test, range(len(workers)), workers=workers) client.gather(futures) + + +@pytest.fixture +def local_cluster(): + from distributed import LocalCluster + + n_workers = 8 + with LocalCluster(n_workers=n_workers, dashboard_address=":0") as cluster: + yield cluster + + +ops_strategy = strategies.lists( + strategies.sampled_from(["broadcast", "allreduce_max", "allreduce_sum"]) +) + + +@pytest.mark.skipif(**tm.no_dask()) +@given(ops=ops_strategy, size=strategies.integers(2**4, 2**16)) +@settings( + deadline=None, + print_blob=True, + max_examples=10, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_ops_restart_comm(local_cluster, ops, size) -> None: + from distributed import Client + + def local_test(w: int, n_workers: int) -> None: + a = np.arange(0, n_workers) + with xgb.dask.CommunicatorContext(**args): + for op in ops: + if op == "broadcast": + b = collective.broadcast(a, root=1) + np.testing.assert_allclose(b, a) + elif op == "allreduce_max": + b = collective.allreduce(a, collective.Op.MAX) + np.testing.assert_allclose(b, a) + elif op == "allreduce_sum": + b = collective.allreduce(a, collective.Op.SUM) + np.testing.assert_allclose(a * n_workers, b) + else: + raise ValueError() + + with Client(local_cluster) as client: + workers = tm.get_client_workers(client) + args = client.sync( + xgb.dask._get_rabit_args, + len(workers), + None, + client, + ) + + workers = tm.get_client_workers(client) + n_workers = len(workers) + + futures = client.map( + local_test, range(len(workers)), workers=workers, n_workers=n_workers + ) + client.gather(futures) + + +@pytest.mark.skipif(**tm.no_dask()) +def test_ops_reuse_comm(local_cluster) -> None: + from distributed import Client + + rng = np.random.default_rng(1994) + n_examples = 10 + ops = rng.choice( + ["broadcast", "allreduce_sum", "allreduce_max"], size=n_examples + ).tolist() + + def local_test(w: int, n_workers: int) -> None: + a = np.arange(0, n_workers) + + with xgb.dask.CommunicatorContext(**args): + for op in ops: + if op == "broadcast": + b = collective.broadcast(a, root=1) + assert np.allclose(b, a) + elif op == "allreduce_max": + c = np.full_like(a, collective.get_rank()) + b = collective.allreduce(c, collective.Op.MAX) + assert np.allclose(b, n_workers - 1), b + elif op == "allreduce_sum": + b = collective.allreduce(a, collective.Op.SUM) + assert np.allclose(a * 8, b) + else: + raise ValueError() + + with Client(local_cluster) as client: + workers = tm.get_client_workers(client) + args = client.sync( + xgb.dask._get_rabit_args, + len(workers), + None, + client, + ) + + n_workers = len(workers) + + futures = client.map( + local_test, range(len(workers)), workers=workers, n_workers=n_workers + ) + client.gather(futures) diff --git a/tests/python/test_updaters.py b/tests/python/test_updaters.py index e7641348d98e..8ec1fdd9d395 100644 --- a/tests/python/test_updaters.py +++ b/tests/python/test_updaters.py @@ -35,10 +35,24 @@ class TestTreeMethod: def test_exact(self, param, num_rounds, dataset): if dataset.name.endswith("-l1"): return - param['tree_method'] = 'exact' + param["tree_method"] = "exact" param = dataset.set_params(param) result = train_result(param, dataset.get_dmat(), num_rounds) - assert tm.non_increasing(result['train'][dataset.metric]) + assert tm.non_increasing(result["train"][dataset.metric]) + + def test_exact_sample_by_node_error(self) -> None: + X, y, w = tm.make_regression(128, 12, False) + with pytest.raises(ValueError, match="column sample by node"): + xgb.train( + {"tree_method": "exact", "colsample_bynode": 0.999}, + xgb.DMatrix(X, y, weight=w), + ) + + xgb.train( + {"tree_method": "exact", "colsample_bynode": 1.0}, + xgb.DMatrix(X, y, weight=w), + num_boost_round=2, + ) @given( exact_parameter_strategy, diff --git a/tests/python/test_with_arrow.py b/tests/python/test_with_arrow.py index 4d12f32df5d3..145cc0f2b3d9 100644 --- a/tests/python/test_with_arrow.py +++ b/tests/python/test_with_arrow.py @@ -8,19 +8,14 @@ from xgboost import testing as tm from xgboost.core import DataSplitMode -try: - import pandas as pd - import pyarrow as pa - import pyarrow.csv as pc -except ImportError: - pass - pytestmark = pytest.mark.skipif( tm.no_arrow()["condition"] or tm.no_pandas()["condition"], reason=tm.no_arrow()["reason"] + " or " + tm.no_pandas()["reason"], ) -dpath = "demo/data/" +import pandas as pd +import pyarrow as pa +import pyarrow.csv as pc class TestArrowTable: diff --git a/tests/python/test_with_modin.py b/tests/python/test_with_modin.py index ce0dbd609215..875c5f7f18c5 100644 --- a/tests/python/test_with_modin.py +++ b/tests/python/test_with_modin.py @@ -1,4 +1,5 @@ import numpy as np +import pandas as pd import pytest import xgboost as xgb @@ -71,7 +72,11 @@ def test_modin(self) -> None: np.testing.assert_array_equal(result.columns, exp) dm = xgb.DMatrix(dummies) assert dm.feature_names == ['B', 'A_X', 'A_Y', 'A_Z'] - assert dm.feature_types == ['int', 'int', 'int', 'int'] + if int(pd.__version__[0]) >= 2: + assert dm.feature_types == ["int", "i", "i", "i"] + else: + assert dm.feature_types == ["int", "int", "int", "int"] + assert dm.num_row() == 3 assert dm.num_col() == 4 diff --git a/tests/python/test_with_pandas.py b/tests/python/test_with_pandas.py index e53e7adccc1f..27be831d3f88 100644 --- a/tests/python/test_with_pandas.py +++ b/tests/python/test_with_pandas.py @@ -248,7 +248,7 @@ def test_pandas_categorical(self, data_split_mode=DataSplitMode.ROW): assert transformed.columns[0].min() == 0 # test missing value - X = pd.DataFrame({"f0": ["a", "b", np.NaN]}) + X = pd.DataFrame({"f0": ["a", "b", np.nan]}) X["f0"] = X["f0"].astype("category") arr, _, _ = xgb.data._transform_pandas_df(X, enable_categorical=True) for c in arr.columns: @@ -280,10 +280,12 @@ def test_pandas_sparse(self): } ) y = pd.Series(pd.arrays.SparseArray(np.random.randn(rows))) - dtrain = xgb.DMatrix(X, y) + with pytest.warns(UserWarning, match="Sparse arrays from pandas"): + dtrain = xgb.DMatrix(X, y) booster = xgb.train({}, dtrain, num_boost_round=4) - predt_sparse = booster.predict(xgb.DMatrix(X)) - predt_dense = booster.predict(xgb.DMatrix(X.sparse.to_dense())) + with pytest.warns(UserWarning, match="Sparse arrays from pandas"): + predt_sparse = booster.predict(xgb.DMatrix(X)) + predt_dense = booster.predict(xgb.DMatrix(X.sparse.to_dense())) np.testing.assert_allclose(predt_sparse, predt_dense) def test_pandas_label( @@ -572,14 +574,16 @@ def test_pandas_sparse_column_split(self): y = pd.Series(pd.arrays.SparseArray(np.random.randn(rows))) def verify_pandas_sparse(): - dtrain = xgb.DMatrix(X, y, data_split_mode=DataSplitMode.COL) + with pytest.warns(UserWarning, match="Sparse arrays from pandas"): + dtrain = xgb.DMatrix(X, y, data_split_mode=DataSplitMode.COL) booster = xgb.train({}, dtrain, num_boost_round=4) - predt_sparse = booster.predict( - xgb.DMatrix(X, data_split_mode=DataSplitMode.COL) - ) - predt_dense = booster.predict( - xgb.DMatrix(X.sparse.to_dense(), data_split_mode=DataSplitMode.COL) - ) + with pytest.warns(UserWarning, match="Sparse arrays from pandas"): + predt_sparse = booster.predict( + xgb.DMatrix(X, data_split_mode=DataSplitMode.COL) + ) + predt_dense = booster.predict( + xgb.DMatrix(X.sparse.to_dense(), data_split_mode=DataSplitMode.COL) + ) np.testing.assert_allclose(predt_sparse, predt_dense) tm.run_with_rabit(world_size=3, test_fn=verify_pandas_sparse) diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 5074707241ba..61f33832ab48 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -1098,9 +1098,10 @@ def test_pandas_input(): np.testing.assert_equal(model.feature_names_in_, np.array(feature_names)) columns = list(train.columns) - random.shuffle(columns, lambda: 0.1) + rng.shuffle(columns) df_incorrect = df[columns] - with pytest.raises(ValueError): + + with pytest.raises(ValueError, match="feature_names mismatch"): model.predict(df_incorrect) clf_isotonic = CalibratedClassifierCV(model, cv="prefit", method="isotonic") @@ -1300,20 +1301,12 @@ def test_estimator_reg(estimator, check): ): estimator.fit(X, y) return - if ( - os.environ["PYTEST_CURRENT_TEST"].find("check_estimators_overwrite_params") - != -1 - ): - # A hack to pass the scikit-learn parameter mutation tests. XGBoost regressor - # returns actual internal default values for parameters in `get_params`, but - # those are set as `None` in sklearn interface to avoid duplication. So we fit - # a dummy model and obtain the default parameters here for the mutation tests. - from sklearn.datasets import make_regression - - X, y = make_regression(n_samples=2, n_features=1) - estimator.set_params(**xgb.XGBRegressor().fit(X, y).get_params()) - - check(estimator) + elif os.environ["PYTEST_CURRENT_TEST"].find("check_regressor_multioutput") != -1: + # sklearn requires float64 + with pytest.raises(AssertionError, match="Got float32"): + check(estimator) + else: + check(estimator) def test_categorical(): @@ -1475,3 +1468,19 @@ def test_fit_none() -> None: with pytest.raises(ValueError, match="labels"): xgb.XGBRegressor().fit(X, None) + + +def test_tags() -> None: + for reg in [xgb.XGBRegressor(), xgb.XGBRFRegressor()]: + tags = reg._more_tags() + assert "non_deterministic" not in tags + assert tags["multioutput"] is True + assert tags["multioutput_only"] is False + + for clf in [xgb.XGBClassifier()]: + tags = clf._more_tags() + assert "multioutput" not in tags + assert tags["multilabel"] is True + + tags = xgb.XGBRanker()._more_tags() + assert "multioutput" not in tags diff --git a/tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py b/tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py index 9949cbf79bd5..ced78a84b0d2 100644 --- a/tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py +++ b/tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py @@ -252,7 +252,7 @@ def test_categorical(self, local_cuda_client: Client) -> None: X_onehot, _ = make_categorical(local_cuda_client, 10000, 30, 13, True) X_onehot = dask_cudf.from_dask_dataframe(X_onehot) - run_categorical(local_cuda_client, "gpu_hist", X, X_onehot, y) + run_categorical(local_cuda_client, "hist", "cuda", X, X_onehot, y) @given( params=hist_parameter_strategy, diff --git a/tests/test_distributed/test_with_dask/test_with_dask.py b/tests/test_distributed/test_with_dask/test_with_dask.py index ffea1d058bf9..ca55716bbd62 100644 --- a/tests/test_distributed/test_with_dask/test_with_dask.py +++ b/tests/test_distributed/test_with_dask/test_with_dask.py @@ -315,8 +315,15 @@ def test_dask_sparse(client: "Client") -> None: ) -def run_categorical(client: "Client", tree_method: str, X, X_onehot, y) -> None: - parameters = {"tree_method": tree_method, "max_cat_to_onehot": 9999} # force onehot +def run_categorical( + client: "Client", tree_method: str, device: str, X, X_onehot, y +) -> None: + # Force onehot + parameters = { + "tree_method": tree_method, + "device": device, + "max_cat_to_onehot": 9999, + } rounds = 10 m = xgb.dask.DaskDMatrix(client, X_onehot, y, enable_categorical=True) by_etl_results = xgb.dask.train( @@ -364,6 +371,7 @@ def check_model_output(model: xgb.dask.Booster) -> None: enable_categorical=True, n_estimators=10, tree_method=tree_method, + device=device, # force onehot max_cat_to_onehot=9999, ) @@ -378,7 +386,10 @@ def check_model_output(model: xgb.dask.Booster) -> None: reg.fit(X, y) # check partition based reg = xgb.dask.DaskXGBRegressor( - enable_categorical=True, n_estimators=10, tree_method=tree_method + enable_categorical=True, + n_estimators=10, + tree_method=tree_method, + device=device, ) reg.fit(X, y, eval_set=[(X, y)]) assert tm.non_increasing(reg.evals_result()["validation_0"]["rmse"]) @@ -398,8 +409,8 @@ def check_model_output(model: xgb.dask.Booster) -> None: def test_categorical(client: "Client") -> None: X, y = make_categorical(client, 10000, 30, 13) X_onehot, _ = make_categorical(client, 10000, 30, 13, True) - run_categorical(client, "approx", X, X_onehot, y) - run_categorical(client, "hist", X, X_onehot, y) + run_categorical(client, "approx", "cpu", X, X_onehot, y) + run_categorical(client, "hist", "cpu", X, X_onehot, y) ft = ["c"] * X.shape[1] reg = xgb.dask.DaskXGBRegressor( diff --git a/tests/test_distributed/test_with_spark/test_spark_local.py b/tests/test_distributed/test_with_spark/test_spark_local.py index b8c16ef1cd05..feb7b18bc035 100644 --- a/tests/test_distributed/test_with_spark/test_spark_local.py +++ b/tests/test_distributed/test_with_spark/test_spark_local.py @@ -929,8 +929,127 @@ def test_gpu_transform(self, clf_data: ClfData) -> None: model_loaded.set_device("cuda") assert model_loaded._run_on_gpu() + def test_validate_gpu_params(self) -> None: + # Standalone + standalone_conf = ( + SparkConf() + .setMaster("spark://foo") + .set("spark.executor.cores", "12") + .set("spark.task.cpus", "1") + .set("spark.executor.resource.gpu.amount", "1") + .set("spark.task.resource.gpu.amount", "0.08") + ) + classifer_on_cpu = SparkXGBClassifier(use_gpu=False) + classifer_on_gpu = SparkXGBClassifier(use_gpu=True) + + # No exception for classifier on CPU + classifer_on_cpu._validate_gpu_params("3.4.0", standalone_conf) + + with pytest.raises( + ValueError, match="XGBoost doesn't support GPU fractional configurations" + ): + classifer_on_gpu._validate_gpu_params("3.3.0", standalone_conf) + + # No issues + classifer_on_gpu._validate_gpu_params("3.4.0", standalone_conf) + classifer_on_gpu._validate_gpu_params("3.4.1", standalone_conf) + classifer_on_gpu._validate_gpu_params("3.5.0", standalone_conf) + classifer_on_gpu._validate_gpu_params("3.5.1", standalone_conf) + + # no spark.executor.resource.gpu.amount + standalone_bad_conf = ( + SparkConf() + .setMaster("spark://foo") + .set("spark.executor.cores", "12") + .set("spark.task.cpus", "1") + .set("spark.task.resource.gpu.amount", "0.08") + ) + msg_match = ( + "The `spark.executor.resource.gpu.amount` is required for training on GPU" + ) + with pytest.raises(ValueError, match=msg_match): + classifer_on_gpu._validate_gpu_params("3.3.0", standalone_bad_conf) + with pytest.raises(ValueError, match=msg_match): + classifer_on_gpu._validate_gpu_params("3.4.0", standalone_bad_conf) + with pytest.raises(ValueError, match=msg_match): + classifer_on_gpu._validate_gpu_params("3.4.1", standalone_bad_conf) + with pytest.raises(ValueError, match=msg_match): + classifer_on_gpu._validate_gpu_params("3.5.0", standalone_bad_conf) + with pytest.raises(ValueError, match=msg_match): + classifer_on_gpu._validate_gpu_params("3.5.1", standalone_bad_conf) + + standalone_bad_conf = ( + SparkConf() + .setMaster("spark://foo") + .set("spark.executor.cores", "12") + .set("spark.task.cpus", "1") + .set("spark.executor.resource.gpu.amount", "1") + ) + msg_match = ( + "The `spark.task.resource.gpu.amount` is required for training on GPU" + ) + with pytest.raises(ValueError, match=msg_match): + classifer_on_gpu._validate_gpu_params("3.3.0", standalone_bad_conf) + + classifer_on_gpu._validate_gpu_params("3.4.0", standalone_bad_conf) + classifer_on_gpu._validate_gpu_params("3.5.0", standalone_bad_conf) + classifer_on_gpu._validate_gpu_params("3.5.1", standalone_bad_conf) + + # Yarn and K8s mode + for mode in ["yarn", "k8s://"]: + conf = ( + SparkConf() + .setMaster(mode) + .set("spark.executor.cores", "12") + .set("spark.task.cpus", "1") + .set("spark.executor.resource.gpu.amount", "1") + .set("spark.task.resource.gpu.amount", "0.08") + ) + with pytest.raises( + ValueError, + match="XGBoost doesn't support GPU fractional configurations", + ): + classifer_on_gpu._validate_gpu_params("3.3.0", conf) + with pytest.raises( + ValueError, + match="XGBoost doesn't support GPU fractional configurations", + ): + classifer_on_gpu._validate_gpu_params("3.4.0", conf) + with pytest.raises( + ValueError, + match="XGBoost doesn't support GPU fractional configurations", + ): + classifer_on_gpu._validate_gpu_params("3.4.1", conf) + with pytest.raises( + ValueError, + match="XGBoost doesn't support GPU fractional configurations", + ): + classifer_on_gpu._validate_gpu_params("3.5.0", conf) + + classifer_on_gpu._validate_gpu_params("3.5.1", conf) + + for mode in ["yarn", "k8s://"]: + bad_conf = ( + SparkConf() + .setMaster(mode) + .set("spark.executor.cores", "12") + .set("spark.task.cpus", "1") + .set("spark.executor.resource.gpu.amount", "1") + ) + msg_match = ( + "The `spark.task.resource.gpu.amount` is required for training on GPU" + ) + with pytest.raises(ValueError, match=msg_match): + classifer_on_gpu._validate_gpu_params("3.3.0", bad_conf) + with pytest.raises(ValueError, match=msg_match): + classifer_on_gpu._validate_gpu_params("3.4.0", bad_conf) + with pytest.raises(ValueError, match=msg_match): + classifer_on_gpu._validate_gpu_params("3.5.0", bad_conf) + + classifer_on_gpu._validate_gpu_params("3.5.1", bad_conf) + def test_skip_stage_level_scheduling(self) -> None: - conf = ( + standalone_conf = ( SparkConf() .setMaster("spark://foo") .set("spark.executor.cores", "12") @@ -943,26 +1062,36 @@ def test_skip_stage_level_scheduling(self) -> None: classifer_on_gpu = SparkXGBClassifier(use_gpu=True) # the correct configurations should not skip stage-level scheduling - assert not classifer_on_gpu._skip_stage_level_scheduling("3.4.0", conf) + assert not classifer_on_gpu._skip_stage_level_scheduling( + "3.4.0", standalone_conf + ) + assert not classifer_on_gpu._skip_stage_level_scheduling( + "3.4.1", standalone_conf + ) + assert not classifer_on_gpu._skip_stage_level_scheduling( + "3.5.0", standalone_conf + ) + assert not classifer_on_gpu._skip_stage_level_scheduling( + "3.5.1", standalone_conf + ) # spark version < 3.4.0 - assert classifer_on_gpu._skip_stage_level_scheduling("3.3.0", conf) - + assert classifer_on_gpu._skip_stage_level_scheduling("3.3.0", standalone_conf) # not run on GPU - assert classifer_on_cpu._skip_stage_level_scheduling("3.4.0", conf) + assert classifer_on_cpu._skip_stage_level_scheduling("3.4.0", standalone_conf) # spark.executor.cores is not set - badConf = ( + bad_conf = ( SparkConf() .setMaster("spark://foo") .set("spark.task.cpus", "1") .set("spark.executor.resource.gpu.amount", "1") .set("spark.task.resource.gpu.amount", "0.08") ) - assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf) + assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", bad_conf) # spark.executor.cores=1 - badConf = ( + bad_conf = ( SparkConf() .setMaster("spark://foo") .set("spark.executor.cores", "1") @@ -970,20 +1099,20 @@ def test_skip_stage_level_scheduling(self) -> None: .set("spark.executor.resource.gpu.amount", "1") .set("spark.task.resource.gpu.amount", "0.08") ) - assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf) + assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", bad_conf) # spark.executor.resource.gpu.amount is not set - badConf = ( + bad_conf = ( SparkConf() .setMaster("spark://foo") .set("spark.executor.cores", "12") .set("spark.task.cpus", "1") .set("spark.task.resource.gpu.amount", "0.08") ) - assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf) + assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", bad_conf) # spark.executor.resource.gpu.amount>1 - badConf = ( + bad_conf = ( SparkConf() .setMaster("spark://foo") .set("spark.executor.cores", "12") @@ -991,20 +1120,20 @@ def test_skip_stage_level_scheduling(self) -> None: .set("spark.executor.resource.gpu.amount", "2") .set("spark.task.resource.gpu.amount", "0.08") ) - assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf) + assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", bad_conf) # spark.task.resource.gpu.amount is not set - badConf = ( + bad_conf = ( SparkConf() .setMaster("spark://foo") .set("spark.executor.cores", "12") .set("spark.task.cpus", "1") .set("spark.executor.resource.gpu.amount", "1") ) - assert not classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf) + assert not classifer_on_gpu._skip_stage_level_scheduling("3.4.0", bad_conf) # spark.task.resource.gpu.amount=1 - badConf = ( + bad_conf = ( SparkConf() .setMaster("spark://foo") .set("spark.executor.cores", "12") @@ -1012,29 +1141,32 @@ def test_skip_stage_level_scheduling(self) -> None: .set("spark.executor.resource.gpu.amount", "1") .set("spark.task.resource.gpu.amount", "1") ) - assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf) - - # yarn - badConf = ( - SparkConf() - .setMaster("yarn") - .set("spark.executor.cores", "12") - .set("spark.task.cpus", "1") - .set("spark.executor.resource.gpu.amount", "1") - .set("spark.task.resource.gpu.amount", "1") - ) - assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf) + assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", bad_conf) + + # For Yarn and K8S + for mode in ["yarn", "k8s://"]: + for gpu_amount in ["0.08", "0.2", "1.0"]: + conf = ( + SparkConf() + .setMaster(mode) + .set("spark.executor.cores", "12") + .set("spark.task.cpus", "1") + .set("spark.executor.resource.gpu.amount", "1") + .set("spark.task.resource.gpu.amount", gpu_amount) + ) + assert classifer_on_gpu._skip_stage_level_scheduling("3.3.0", conf) + assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", conf) + assert classifer_on_gpu._skip_stage_level_scheduling("3.4.1", conf) + assert classifer_on_gpu._skip_stage_level_scheduling("3.5.0", conf) - # k8s - badConf = ( - SparkConf() - .setMaster("k8s://") - .set("spark.executor.cores", "12") - .set("spark.task.cpus", "1") - .set("spark.executor.resource.gpu.amount", "1") - .set("spark.task.resource.gpu.amount", "1") - ) - assert classifer_on_gpu._skip_stage_level_scheduling("3.4.0", badConf) + # This will be fixed when spark 4.0.0 is released. + if gpu_amount == "1.0": + assert classifer_on_gpu._skip_stage_level_scheduling("3.5.1", conf) + else: + # Starting from 3.5.1+, stage-level scheduling is working for Yarn and K8s + assert not classifer_on_gpu._skip_stage_level_scheduling( + "3.5.1", conf + ) class XgboostLocalTest(SparkTestCase): @@ -1521,9 +1653,9 @@ def ltr_data(spark: SparkSession) -> Generator[LTRData, None, None]: [1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [9.0, 4.0, 8.0], - [np.NaN, 1.0, 5.5], - [np.NaN, 6.0, 7.5], - [np.NaN, 8.0, 9.5], + [np.nan, 1.0, 5.5], + [np.nan, 6.0, 7.5], + [np.nan, 8.0, 9.5], ] ) qid_train = np.array([0, 0, 0, 1, 1, 1]) @@ -1534,9 +1666,9 @@ def ltr_data(spark: SparkSession) -> Generator[LTRData, None, None]: [1.5, 2.0, 3.0], [4.5, 5.0, 6.0], [9.0, 4.5, 8.0], - [np.NaN, 1.0, 6.0], - [np.NaN, 6.0, 7.0], - [np.NaN, 8.0, 10.5], + [np.nan, 1.0, 6.0], + [np.nan, 6.0, 7.0], + [np.nan, 8.0, 10.5], ] )