Skip to content

Commit

Permalink
ci: pin PT to 2.3.1 when using CUDA (#4009)
Browse files Browse the repository at this point in the history
PT 2.4.0 requires cudnn 9, incompatible with the latest TF with cudnn 8.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Updated the build process to optimize resource usage by removing
unnecessary files.
- Specified a fixed version for PyTorch to ensure consistent
functionality across environments.
	
- **Documentation**
- Added configuration settings to manage library dependencies and ensure
compatibility.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Jul 24, 2024
1 parent a6ea2c1 commit 9e14d45
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .github/workflows/build_wheel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ jobs:
- variant: "_cu11"
cuda_version: "11"
steps:
- name: Delete huge unnecessary tools folder
run: rm -rf /opt/hostedtoolcache
- uses: actions/checkout@v4
- uses: actions/download-artifact@v4
with:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ jobs:
&& sudo apt-get -y install cuda-12-3 libcudnn8=8.9.5.*-1+cuda12.3
if: false # skip as we use nvidia image
- run: python -m pip install -U uv
- run: source/install/uv_with_retry.sh pip install --system "tensorflow>=2.15.0rc0" "torch>=2.2.0"
- run: source/install/uv_with_retry.sh pip install --system "tensorflow>=2.15.0rc0" "torch==2.3.1.*"
- run: |
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,8 @@ PATH = "/usr/lib64/mpich/bin:$PATH"
UV_EXTRA_INDEX_URL = "https://download.pytorch.org/whl/cpu"
# trick to find the correction version of mpich
CMAKE_PREFIX_PATH="/opt/python/cp311-cp311/"
# PT 2.4.0 requires cudnn 9, incompatible with TF with cudnn 8
PYTORCH_VERSION = "2.3.1"

[tool.cibuildwheel.windows]
test-extras = ["cpu", "torch"]
Expand Down

0 comments on commit 9e14d45

Please sign in to comment.