Skip to content

Commit

Permalink
Support PyTorch 2.3.0 (#1467)
Browse files Browse the repository at this point in the history
* Support latest PyTorch release

* Update ci.yaml

* pytorch: add v2.2, drop v1.11

* fix type bug

* delete legacy code

* ci: explicit version numbers

---------

Co-authored-by: ClaudiaComito <39374113+ClaudiaComito@users.noreply.github.com>
Co-authored-by: Michael Tarnawa <m.tarnawa@fz-juelich.de>
  • Loading branch information
3 people authored Jun 5, 2024
1 parent c8a8de5 commit ee0d72a
Show file tree
Hide file tree
Showing 7 changed files with 6 additions and 45 deletions.
6 changes: 2 additions & 4 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,13 @@ jobs:
mpi: [ 'openmpi' ]
install-options: [ '.', '.[hdf5,netcdf]' ]
pytorch-version:
- 'torch==1.11.0+cpu torchvision==0.12.0+cpu torchaudio==0.11.0'
- 'torch==1.12.1+cpu torchvision==0.13.1+cpu torchaudio==0.12.1'
- 'torch==1.13.1+cpu torchvision==0.14.1+cpu torchaudio==0.13.1'
- 'torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2'
- 'torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2'
- 'torch torchvision torchaudio'
- 'torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2'
- 'torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0'
exclude:
- py-version: '3.11'
pytorch-version: 'torch==1.11.0+cpu torchvision==0.12.0+cpu torchaudio==0.11.0'
- py-version: '3.11'
pytorch-version: 'torch==1.12.1+cpu torchvision==0.13.1+cpu torchaudio==0.12.1'
- py-version: '3.11'
Expand Down
18 changes: 2 additions & 16 deletions heat/core/arithmetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1954,14 +1954,7 @@ def left_shift(
elif dtypes[dt] == types.bool:
arrs[dt] = types.int(arrs[dt])

try:
result = _operations.__binary_op(torch.bitwise_left_shift, t1, t2, out, where)
except AttributeError: # pragma: no cover
result = _operations.__binary_op(
torch.Tensor.__lshift__, t1, t2, out, where
) # pytorch < 1.10

return result
return _operations.__binary_op(torch.bitwise_left_shift, t1, t2, out, where)


def _lshift(self, other):
Expand Down Expand Up @@ -2875,14 +2868,7 @@ def right_shift(
elif dtypes[dt] == types.bool:
arrs[dt] = types.int(arrs[dt])

try:
result = _operations.__binary_op(torch.bitwise_right_shift, t1, t2, out, where)
except AttributeError: # pragma: no cover
result = _operations.__binary_op(
torch.Tensor.__rshift__, t1, t2, out, where
) # pytorch < 1.10

return result
return _operations.__binary_op(torch.bitwise_right_shift, t1, t2, out, where)


def _rshift(self, other):
Expand Down
2 changes: 1 addition & 1 deletion heat/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def load_hdf5(
data = handle[dataset]
gshape = data.shape
if split is not None:
gshape = np.array(gshape)
gshape = list(gshape)
gshape[split] = int(gshape[split] * load_fraction)
gshape = tuple(gshape)
dims = len(gshape)
Expand Down
Binary file removed heat/core/linalg/.DS_Store
Binary file not shown.
13 changes: 0 additions & 13 deletions heat/core/stride_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,6 @@ def broadcast_shape(shape_a: Tuple[int, ...], shape_b: Tuple[int, ...]) -> Tuple
"""
try:
resulting_shape = torch.broadcast_shapes(shape_a, shape_b)
except AttributeError: # torch < 1.8
it = itertools.zip_longest(shape_a[::-1], shape_b[::-1], fillvalue=1)
resulting_shape = max(len(shape_a), len(shape_b)) * [None]
for i, (a, b) in enumerate(it):
if a == 0 and b == 1 or b == 0 and a == 1:
resulting_shape[i] = 0
elif a == 1 or b == 1 or a == b:
resulting_shape[i] = max(a, b)
else:
raise ValueError(
f"operands could not be broadcast, input shapes {shape_a} {shape_b}"
)
return tuple(resulting_shape[::-1])
except TypeError:
raise TypeError(f"operand 1 must be tuple of ints, not {type(shape_a)}")
except NameError:
Expand Down
10 changes: 0 additions & 10 deletions heat/fft/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,11 +858,6 @@ def ihfft2(
-----
This function requires MPI communication if the input array is distributed and the split axis is transformed.
"""
torch_has_ihfftn = hasattr(torch.fft, "ihfftn")
if not torch_has_ihfftn: # pragma: no cover
raise NotImplementedError(
f"n-dim inverse Hermitian FFTs not implemented for torch < 1.11.0. Your environment runs torch {torch.__version__}. Please upgrade torch."
)
return __real_fftn_op(x, torch.fft.ihfft2, s=s, axes=axes, norm=norm)


Expand Down Expand Up @@ -896,11 +891,6 @@ def ihfftn(
-----
This function requires MPI communication if the input array is distributed and the split axis is transformed.
"""
torch_has_ihfftn = hasattr(torch.fft, "ihfftn")
if not torch_has_ihfftn: # pragma: no cover
raise NotImplementedError(
f"n-dim inverse Hermitian FFTs not implemented for torch < 1.11.0. Your environment runs torch {torch.__version__}. Please upgrade torch."
)
return __real_fftn_op(x, torch.fft.ihfftn, s=s, axes=axes, norm=norm)


Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
install_requires=[
"mpi4py>=3.0.0",
"numpy>=1.22.0",
"torch>=1.11.0, <2.2.3",
"torch>=1.12.0, <2.3.1",
"scipy>=1.10.0",
"pillow>=6.0.0",
"torchvision>=0.12.0",
Expand Down

0 comments on commit ee0d72a

Please sign in to comment.