Skip to content

Commit

Permalink
Fix pytorch versions ci (#3289)
Browse files Browse the repository at this point in the history
* Fixed torch.pi usage unavailable in pytorch 1.5.1, 1.8.1

* Fix pytorch versions CI failures
  • Loading branch information
vfdev-5 authored Oct 2, 2024
1 parent 6a1ef07 commit 302c707
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
14 changes: 8 additions & 6 deletions .github/workflows/pytorch-version-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,8 @@ jobs:
# will drop python version and related pytorch versions
python-version: [3.8, 3.9, "3.10"]
pytorch-version:
[2.3.1, 2.2.2, 2.1.2, 2.0.1, 1.13.1, 1.12.1, 1.10.0, 1.8.1, 1.5.1]
[2.3.1, 2.2.2, 2.1.2, 2.0.1, 1.13.1, 1.12.1, 1.10.0, 1.8.1]
exclude:
- pytorch-version: 1.5.1
python-version: 3.9
- pytorch-version: 1.5.1
python-version: "3.10"

# disabling python 3.9 support with PyTorch 1.7.1 and 1.8.1, to stop repeated pytorch-version test fail.
# https://github.com/pytorch/ignite/issues/2383
- pytorch-version: 1.8.1
Expand Down Expand Up @@ -74,6 +69,13 @@ jobs:
shell: bash -l {0}
run: |
conda install pytorch=${{ matrix.pytorch-version }} torchvision cpuonly python=${{ matrix.python-version }} -c pytorch
# We should install numpy<2.0 for pytorch<2.3
numpy_one_pth_version=$(python -c "import torch; print(float('.'.join(torch.__version__.split('.')[:2])) < 2.3)")
if [ "${numpy_one_pth_version}" == "True" ]; then
pip install -U "numpy<2.0"
fi
pip install -r requirements-dev.txt
python setup.py install
Expand Down
1 change: 1 addition & 0 deletions ignite/metrics/hsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def update(self, output: Sequence[Tensor]) -> None:

vx: Union[Tensor, float]
if self.sigma_x < 0:
# vx = torch.quantile(dxx, 0.5)
vx = torch.quantile(dxx, 0.5)
else:
vx = self.sigma_x**2
Expand Down
3 changes: 2 additions & 1 deletion tests/ignite/metrics/test_hsic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
from typing import Tuple

import numpy as np
Expand Down Expand Up @@ -74,7 +75,7 @@ def test_case(request) -> Tuple[Tensor, Tensor, int]:
N = 200
b = 20
x = torch.randn(N, 5)
y = x @ torch.normal(0.0, torch.pi, size=(5, 3))
y = x @ torch.normal(0.0, math.pi, size=(5, 3))
y = (
torch.stack([torch.sin(y[:, 0]), torch.cos(y[:, 1]), torch.exp(y[:, 2])], dim=1)
+ torch.randn_like(y) * 1e-4
Expand Down

0 comments on commit 302c707

Please sign in to comment.