# Add GitHub Actions workflows for running continuous tests with Pytest #4
# Workflow file for this run
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
name: CI - Build JAX Artifacts

# `on` is a YAML 1.1 boolean for generic parsers; GitHub's workflow loader
# treats it as the trigger key, so silence yamllint's truthy rule here only.
on:  # yamllint disable-line rule:truthy
  # Run for pull requests targeting main.
  pull_request:
    branches:
      - main
  # Allow manual runs, optionally pausing for a remote debugging connection.
  workflow_dispatch:
    inputs:
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: choice
        required: true
        # Quoted so 'yes'/'no' stay strings instead of YAML 1.1 booleans.
        default: 'no'
        options:
          - 'yes'
          - 'no'

# At most one in-flight run per workflow and branch/PR; a newer run cancels
# the stale one. github.head_ref is only populated for pull_request events,
# so other triggers (e.g. workflow_dispatch) fall back to github.ref.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true
jobs:
  build-artifacts:
    # Only run when the repository executing the workflow is not a fork.
    if: github.event.repository.fork == false
    defaults:
      run:
        # Explicitly set the shell to bash to override the default Windows
        # shell (cmd), so the same scripts run on every runner.
        shell: bash
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        runner: ["windows-x86-n2-16", "linux-x86-n2-16", "linux-arm64-t2a-48"]
        artifact: ["jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"]
        python: ["3.10", "3.11", "3.12", "3.13"]
        exclude:
          # Don't build jax-cuda-pjrt and jax-cuda-plugin on windows-x86-n2-16
          - runner: "windows-x86-n2-16"
            artifact: "jax-cuda-pjrt"
          - runner: "windows-x86-n2-16"
            artifact: "jax-cuda-plugin"
          # Build jax-cuda-pjrt only once (with Python 3.13) instead of for
          # every Python version. The versions MUST be quoted to match the
          # string entries above: unquoted 3.10 parses as the float 3.1 and
          # would never match "3.10", silently disabling the exclusion.
          - artifact: "jax-cuda-pjrt"
            python: "3.10"
          - artifact: "jax-cuda-pjrt"
            python: "3.11"
          - artifact: "jax-cuda-pjrt"
            python: "3.12"
    runs-on: ${{ matrix.runner }}
    # Linux jobs run inside the ML build image for their CPU architecture;
    # for Windows runners the expression yields null, so no container is used
    # and the job runs directly on the runner.
    container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
        (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') ||
        (contains(matrix.runner, 'windows-x86') && null) }}
    env:
      # Python version for the build, taken from the matrix.
      JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}"
    name: Build ${{ matrix.artifact }} on ${{ matrix.runner }} with Python ${{ matrix.python }}
    steps:
      # checkout@v4 runs on Node 20; v3 targets the deprecated Node 16 runtime.
      - uses: actions/checkout@v4
      - name: Enable RBE on platforms where its supported
        run: |
          os=$(uname -s | awk '{print tolower($0)}')
          arch=$(uname -m)
          # Enable RBE if building on Linux x86 or Windows x86. Windows is
          # detected by "msys_nt" in the lower-cased `uname -s` (assumes an
          # MSYS-based bash on the Windows runner — confirm).
          if [[ ($os == "linux" || $os =~ "msys_nt" ) && $arch == "x86_64" ]]; then
            # Variables appended to $GITHUB_ENV are visible to later steps.
            echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> "$GITHUB_ENV"
          fi
      # Halt for testing: hands the halt-for-connection dispatch input to the
      # external ci_connection action (empty on pull_request runs).
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: Build ${{ matrix.artifact }}
        run: ./ci/build_artifacts.sh "${{ matrix.artifact }}"