Add GitHub action workflow for Bazel CUDA continuous tests #4
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# CI - Wheel Tests (Continuous)
#
# This workflow builds JAX artifacts and runs CPU/CUDA tests.
#
# It orchestrates the following:
# 1. build-jaxlib-artifact: Calls the `build_artifacts.yml` workflow to build jaxlib and
#                           upload it to a GCS bucket.
# 2. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow to download the jaxlib wheel built
#                    in the previous step and run the CPU tests.
# 3. build-cuda-artifacts: Calls the `build_artifacts.yml` workflow to build the CUDA artifacts
#                          and upload them to a GCS bucket.
# 4. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow to download the jaxlib and CUDA
#                     artifacts built in the previous steps and run the CUDA tests.
# 5. run-bazel-test-cuda: Calls the `bazel_cuda_non_rbe.yml` workflow to download the jaxlib
#                         and CUDA artifacts built in the previous steps and run the CUDA
#                         tests with Bazel.
# Workflow display name. GitHub exposes it as `github.workflow`, which the jobs
# interpolate into their GCS artifact paths.
name: 'CI - Wheel Tests (Continuous)'
on:
  # Scheduled trigger, currently disabled:
  # schedule:
  #   - cron: "0 */2 * * *"  # Run once every 2 hours
  #
  # TODO: For testing purposes only; remove the pull_request event before submitting.
  pull_request:
    branches: [main]
concurrency:
  # At most one active run per workflow and PR head branch (or ref, when head_ref is
  # empty); starting a newer run in the same group cancels the in-flight one.
  cancel-in-progress: true
  group: '${{ github.workflow }}-${{ github.head_ref || github.ref }}'
jobs:
  # Builds the jaxlib wheel and uploads it under this run's GCS prefix.
  build-jaxlib-artifact:
    uses: ./.github/workflows/build_artifacts.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Runner OS and Python values need to match the matrix strategy in the CPU tests job.
        # Enable Windows ("windows-x86-n2-16") after we have fixed the runner issue.
        runner: ["linux-x86-n2-16", "linux-arm64-c4a-64"]
        artifact: ["jaxlib"]
        python: ["3.10"]
    with:
      runner: ${{ matrix.runner }}
      artifact: ${{ matrix.artifact }}
      python: ${{ matrix.python }}
      clone_main_xla: 1
      upload_artifacts: true
      # Every gcs_upload_uri / gcs_download_uri in this workflow must use the same prefix,
      # so that the test jobs find exactly what the build jobs uploaded.
      gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}'

  # Builds the CUDA plugin and PJRT artifacts and uploads them under the same GCS prefix.
  build-cuda-artifacts:
    uses: ./.github/workflows/build_artifacts.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Python values need to match the matrix strategy in the GPU test jobs below.
        runner: ["linux-x86-n2-16"]
        artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"]
        python: ["3.10"]
    with:
      runner: ${{ matrix.runner }}
      artifact: ${{ matrix.artifact }}
      python: ${{ matrix.python }}
      clone_main_xla: 1
      upload_artifacts: true
      gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}'

  # Downloads the jaxlib wheel from GCS and runs the CPU tests.
  run-pytest-cpu:
    needs: build-jaxlib-artifact
    uses: ./.github/workflows/pytest_cpu.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Runner OS and Python values need to match the matrix strategy in the
        # build-jaxlib-artifact job above.
        runner: ["linux-x86-n2-64", "linux-arm64-c4a-64"]
        python: ["3.10"]
        enable-x64: [1, 0]
    with:
      runner: ${{ matrix.runner }}
      python: ${{ matrix.python }}
      enable-x64: ${{ matrix.enable-x64 }}
      gcs_download_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}'

  # Downloads the jaxlib and CUDA artifacts from GCS and runs the CUDA tests with pytest.
  run-pytest-cuda:
    needs: [build-jaxlib-artifact, build-cuda-artifacts]
    uses: ./.github/workflows/pytest_cuda.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Python values need to match the matrix strategy in the artifact build jobs above.
        runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu"]
        python: ["3.10"]
        cuda: ["12.3", "12.1"]
        enable-x64: [1, 0]
        exclude:
          # Run only a single configuration on H100 (CUDA 12.3 with x64 enabled) to save
          # resources.
          - runner: "linux-x86-a3-8g-h100-8gpu"
            python: "3.10"
            cuda: "12.1"
          - runner: "linux-x86-a3-8g-h100-8gpu"
            python: "3.10"
            enable-x64: 0
    with:
      runner: ${{ matrix.runner }}
      python: ${{ matrix.python }}
      cuda: ${{ matrix.cuda }}
      enable-x64: ${{ matrix.enable-x64 }}
      gcs_download_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}'

  # Downloads the jaxlib and CUDA artifacts from GCS and runs the CUDA tests with Bazel.
  run-bazel-test-cuda:
    needs: [build-jaxlib-artifact, build-cuda-artifacts]
    uses: ./.github/workflows/bazel_cuda_non_rbe.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Python values need to match the matrix strategy in the artifact build jobs above.
        runner: ["linux-x86-g2-48-l4-4gpu"]
        python: ["3.10"]
        enable-x64: [1, 0]
    with:
      runner: ${{ matrix.runner }}
      python: ${{ matrix.python }}
      enable-x64: ${{ matrix.enable-x64 }}
      # Must match the build jobs' gcs_upload_uri (`.../jax-github-actions/jax/...`);
      # any other prefix points at a location nothing in this workflow uploads to.
      gcs_download_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}'