-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add GitHub action workflow for Bazel CUDA continuous tests
PiperOrigin-RevId: 705238265
- Loading branch information
1 parent
57b2154
commit 2950e7f
Showing
11 changed files
with
529 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
name: CI - Bazel CUDA tests (Non RBE) | ||
|
||
on: | ||
workflow_dispatch: | ||
inputs: | ||
halt-for-connection: | ||
description: 'Should this workflow run wait for a remote connection?' | ||
type: choice | ||
required: true | ||
default: 'no' | ||
options: | ||
- 'yes' | ||
- 'no' | ||
workflow_call: | ||
inputs: | ||
runner: | ||
description: "Which runner should the workflow run on?" | ||
type: string | ||
required: true | ||
default: "linux-x86-n2-16" | ||
python: | ||
description: "Which python version to test?" | ||
type: string | ||
required: true | ||
default: "3.12" | ||
enable-x64: | ||
description: "Should x64 mode be enabled?" | ||
type: string | ||
required: true | ||
default: "0" | ||
gcs_download_uri: | ||
description: "GCS location URI from where the artifacts should be downloaded" | ||
required: true | ||
default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' | ||
type: string | ||
|
||
jobs: | ||
run-tests: | ||
runs-on: ${{ inputs.runner }} | ||
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-rbe:latest" | ||
|
||
env: | ||
JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }} | ||
JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }} | ||
|
||
name: "Bazel single accelerator and multi-accelerator CUDA tests (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }}" | ||
|
||
steps: | ||
- uses: actions/checkout@v4 | ||
# Halt for testing | ||
- name: Wait For Connection | ||
uses: google-ml-infra/actions/ci_connection@main | ||
with: | ||
halt-dispatch-input: ${{ inputs.halt-for-connection }} | ||
- name: Set env vars for use in artifact download URL | ||
run: | | ||
os=$(uname -s | awk '{print tolower($0)}') | ||
arch=$(uname -m) | ||
# Get the major and minor version of Python. | ||
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 | ||
python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.') | ||
echo "OS=${os}" >> $GITHUB_ENV | ||
echo "ARCH=${arch}" >> $GITHUB_ENV | ||
echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV | ||
- name: Download the wheel artifacts from GCS | ||
run: >- | ||
mkdir -p $(pwd)/dist && | ||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ && | ||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ && | ||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*cuda*pjrt*${OS}*${ARCH}*.whl $(pwd)/dist/ | ||
- name: Run Bazel CUDA tests (non RBE) | ||
run: ./ci/run_bazel_test_cuda_non_rbe.sh |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
# CI - Build JAX Artifacts | ||
# This workflow builds JAX wheels (jax, jaxlib, jax-cuda-plugin, and jax-cuda-pjrt) and optionally | ||
# uploads them to a Google Cloud Storage (GCS) bucket. It can be triggered manually via | ||
# workflow_dispatch or called by other workflows via workflow_call. | ||
name: CI - Build JAX Artifacts | ||
|
||
on: | ||
workflow_dispatch: | ||
inputs: | ||
runner: | ||
description: "Which runner should the workflow run on?" | ||
type: choice | ||
required: true | ||
default: "linux-x86-n2-16" | ||
options: | ||
- "linux-x86-n2-16" | ||
- "linux-arm64-c4a-64" | ||
- "windows-x86-n2-16" | ||
artifact: | ||
description: "Which JAX artifact to build?" | ||
type: choice | ||
required: true | ||
default: "jaxlib" | ||
options: | ||
- "jax" | ||
- "jaxlib" | ||
- "jax-cuda-plugin" | ||
- "jax-cuda-pjrt" | ||
python: | ||
description: "Which python version should the artifact be built for?" | ||
type: choice | ||
required: false | ||
default: "3.12" | ||
options: | ||
- "3.10" | ||
- "3.11" | ||
- "3.12" | ||
- "3.13" | ||
clone_main_xla: | ||
description: "Should latest XLA be used?" | ||
type: choice | ||
required: false | ||
default: "0" | ||
options: | ||
- "1" | ||
- "0" | ||
halt-for-connection: | ||
description: 'Should this workflow run wait for a remote connection?' | ||
type: choice | ||
required: false | ||
default: 'no' | ||
options: | ||
- 'yes' | ||
- 'no' | ||
workflow_call: | ||
inputs: | ||
runner: | ||
description: "Which runner should the workflow run on?" | ||
type: string | ||
required: true | ||
default: "linux-x86-n2-16" | ||
artifact: | ||
description: "Which JAX artifact to build?" | ||
type: string | ||
required: true | ||
default: "jaxlib" | ||
python: | ||
description: "Which python version should the artifact be built for?" | ||
type: string | ||
required: false | ||
default: "3.12" | ||
clone_main_xla: | ||
description: "Should latest XLA be used?" | ||
type: string | ||
required: false | ||
default: "0" | ||
upload_artifacts: | ||
description: "Should the artifacts be uploaded to a GCS bucket?" | ||
required: true | ||
default: true | ||
type: boolean | ||
gcs_upload_uri: | ||
description: "GCS location prefix to where the artifacts should be uploaded" | ||
required: true | ||
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}' | ||
type: string | ||
|
||
permissions: | ||
contents: read | ||
|
||
jobs: | ||
build-artifacts: | ||
defaults: | ||
run: | ||
# Explicitly set the shell to bash to override Windows's default (cmd) | ||
shell: bash | ||
|
||
runs-on: ${{ inputs.runner }} | ||
|
||
container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || | ||
(contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || | ||
(contains(inputs.runner, 'windows-x86') && null) }} | ||
|
||
env: | ||
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" | ||
JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}" | ||
|
||
name: Build ${{ inputs.artifact }} (${{ inputs.runner }}, Python ${{ inputs.python }}, clone main XLA=${{ inputs.clone_main_xla }}) | ||
|
||
steps: | ||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 | ||
- name: Enable RBE if building on Linux x86 or Windows x86 | ||
if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86') | ||
run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV | ||
# Halt for testing | ||
- name: Wait For Connection | ||
uses: google-ml-infra/actions/ci_connection@main | ||
with: | ||
halt-dispatch-input: ${{ inputs.halt-for-connection }} | ||
- name: Build ${{ inputs.artifact }} | ||
run: ./ci/build_artifacts.sh "${{ inputs.artifact }}" | ||
- name: Upload artifacts to a GCS bucket (non-Windows runs) | ||
if: >- | ||
${{ inputs.upload_artifacts && !contains(inputs.runner, 'windows-x86') }} | ||
run: gsutil -m cp -r $(pwd)/dist/*.whl "${{ inputs.gcs_upload_uri }}"/ | ||
# Set shell to cmd to avoid path errors when using gcloud commands on Windows | ||
- name: Upload artifacts to a GCS bucket (Windows runs) | ||
if: >- | ||
${{ inputs.upload_artifacts && contains(inputs.runner, 'windows-x86') }} | ||
shell: cmd | ||
run: gsutil -m cp -r dist/*.whl "${{ inputs.gcs_upload_uri }}"/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# CI - Pytest CPU | ||
# | ||
# This workflow runs the CPU tests with Pytest. It can only be triggered by other workflows via | ||
# `workflow_call`. It is used by the "CI - Wheel Tests" workflows to run the Pytest CPU tests. | ||
# | ||
# It consists of the following job: | ||
# run-tests: | ||
# - Downloads the jaxlib wheel from a GCS bucket. | ||
# - Executes the `run_pytest_cpu.sh` script, which performs the following actions: | ||
# - Installs the downloaded jaxlib wheel. | ||
# - Runs the CPU tests with Pytest. | ||
name: CI - Pytest CPU | ||
|
||
on: | ||
workflow_call: | ||
inputs: | ||
runner: | ||
description: "Which runner should the workflow run on?" | ||
type: string | ||
required: true | ||
default: "linux-x86-n2-16" | ||
python: | ||
description: "Which python version should the artifact be built for?" | ||
type: string | ||
required: true | ||
default: "3.12" | ||
enable-x64: | ||
description: "Should x64 mode be enabled?" | ||
type: string | ||
required: true | ||
default: "0" | ||
gcs_download_uri: | ||
description: "GCS location prefix from where the artifacts should be downloaded" | ||
required: true | ||
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}' | ||
type: string | ||
halt-for-connection: | ||
description: 'Should this workflow run wait for a remote connection?' | ||
type: boolean | ||
required: false | ||
default: false | ||
|
||
jobs: | ||
run-tests: | ||
defaults: | ||
run: | ||
# Explicitly set the shell to bash to override Windows's default (cmd) | ||
shell: bash | ||
runs-on: ${{ inputs.runner }} | ||
container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || | ||
(contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || | ||
(contains(inputs.runner, 'windows-x86') && null) }} | ||
|
||
name: "Pytest CPU (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" | ||
|
||
env: | ||
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" | ||
JAXCI_PYTHON: "python${{ inputs.python }}" | ||
JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}" | ||
|
||
steps: | ||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 | ||
# Halt for testing | ||
- name: Wait For Connection | ||
uses: google-ml-infra/actions/ci_connection@main | ||
with: | ||
halt-dispatch-input: ${{ inputs.halt-for-connection }} | ||
- name: Set env vars for use in artifact download URL | ||
run: | | ||
os=$(uname -s | awk '{print tolower($0)}') | ||
arch=$(uname -m) | ||
# Adjust name for Windows | ||
if [[ $os =~ "msys_nt" ]]; then | ||
os="windows" | ||
fi | ||
# Get the major and minor version of Python. | ||
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 | ||
python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.') | ||
echo "OS=${os}" >> $GITHUB_ENV | ||
echo "ARCH=${arch}" >> $GITHUB_ENV | ||
echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV | ||
- name: Download jaxlib wheel from GCS (non-Windows runs) | ||
if: ${{ !contains(matrix.runner, 'windows-x86') }} | ||
run: >- | ||
mkdir -p $(pwd)/dist && | ||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ | ||
- name: Download jaxlib wheel from GCS (Windows runs) | ||
if: ${{ contains(matrix.runner, 'windows-x86') }} | ||
shell: cmd | ||
run: >- | ||
mkdir dist && | ||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl dist/ | ||
- name: Install Python dependencies | ||
run: $JAXCI_PYTHON -m pip install -r build/requirements.in | ||
- name: Run Pytest CPU tests | ||
run: ./ci/run_pytest_cpu.sh |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# CI - Pytest CUDA | ||
# | ||
# This workflow runs the CUDA tests with Pytest. It can only be triggered by other workflows via | ||
# `workflow_call`. It is used by the `CI - Wheel Tests` workflows to run the Pytest CUDA tests. | ||
# | ||
# It consists of the following job: | ||
# run-tests: | ||
# - Downloads the jaxlib and CUDA artifacts from a GCS bucket. | ||
# - Executes the `run_pytest_cuda.sh` script, which performs the following actions: | ||
# - Installs the downloaded jaxlib wheel. | ||
# - Runs the CUDA tests with Pytest. | ||
name: CI - Pytest CUDA | ||
|
||
on: | ||
workflow_call: | ||
inputs: | ||
runner: | ||
description: "Which runner should the workflow run on?" | ||
type: string | ||
required: true | ||
default: "linux-x86-n2-16" | ||
python: | ||
description: "Which python version to test?" | ||
type: string | ||
required: true | ||
default: "3.12" | ||
cuda: | ||
description: "Which CUDA version to test?" | ||
type: string | ||
required: true | ||
default: "12.3" | ||
enable-x64: | ||
description: "Should x64 mode be enabled?" | ||
type: string | ||
required: true | ||
default: "0" | ||
gcs_download_uri: | ||
description: "GCS location prefix from where the artifacts should be downloaded" | ||
required: true | ||
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}' | ||
type: string | ||
halt-for-connection: | ||
description: 'Should this workflow run wait for a remote connection?' | ||
type: boolean | ||
required: false | ||
default: false | ||
|
||
jobs: | ||
run-tests: | ||
runs-on: ${{ inputs.runner }} | ||
# TODO: Update to the generic ML ecosystem test containers when they are ready. | ||
container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest') || | ||
(contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest') }} | ||
name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" | ||
|
||
env: | ||
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" | ||
JAXCI_PYTHON: "python${{ inputs.python }}" | ||
JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}" | ||
|
||
steps: | ||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 | ||
# Halt for testing | ||
- name: Wait For Connection | ||
uses: google-ml-infra/actions/ci_connection@main | ||
with: | ||
halt-dispatch-input: ${{ inputs.halt-for-connection }} | ||
- name: Set env vars for use in artifact download URL | ||
run: | | ||
os=$(uname -s | awk '{print tolower($0)}') | ||
arch=$(uname -m) | ||
# Get the major and minor version of Python. | ||
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 | ||
python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.') | ||
echo "OS=${os}" >> $GITHUB_ENV | ||
echo "ARCH=${arch}" >> $GITHUB_ENV | ||
echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV | ||
- name: Download the wheel artifacts from GCS | ||
run: >- | ||
mkdir -p $(pwd)/dist && | ||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ && | ||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ && | ||
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*cuda*pjrt*${OS}*${ARCH}*.whl $(pwd)/dist/ | ||
- name: Install Python dependencies | ||
run: $JAXCI_PYTHON -m pip install -r build/requirements.in | ||
- name: Run Pytest CUDA tests | ||
run: ./ci/run_pytest_cuda.sh |
Oops, something went wrong.