From 4ae77a5286c106864dce978412cd6a759ee1e0c2 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 5 Jun 2024 11:16:00 +0200 Subject: [PATCH] Support PyTorch 2.3.0 (#1467) * Support latest PyTorch release * Update ci.yaml * pytorch: add v2.2, drop v1.11 * fix type bug * delete legacy code * ci: explicit version numbers --------- Co-authored-by: ClaudiaComito <39374113+ClaudiaComito@users.noreply.github.com> Co-authored-by: Michael Tarnawa (cherry picked from commit ee0d72a4754d6caea85daaacc4fb43fd16507cb1) --- .github/workflows/ci.yaml | 6 ++---- heat/core/arithmetics.py | 18 ++---------------- heat/core/io.py | 2 +- heat/core/linalg/.DS_Store | Bin 6148 -> 0 bytes heat/core/stride_tricks.py | 13 ------------- heat/fft/fft.py | 10 ---------- setup.py | 2 +- 7 files changed, 6 insertions(+), 45 deletions(-) delete mode 100644 heat/core/linalg/.DS_Store diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 796462c565..5fbc9c3664 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -19,15 +19,13 @@ jobs: mpi: [ 'openmpi' ] install-options: [ '.', '.[hdf5,netcdf]' ] pytorch-version: - - 'torch==1.11.0+cpu torchvision==0.12.0+cpu torchaudio==0.11.0' - 'torch==1.12.1+cpu torchvision==0.13.1+cpu torchaudio==0.12.1' - 'torch==1.13.1+cpu torchvision==0.14.1+cpu torchaudio==0.13.1' - 'torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2' - 'torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2' - - 'torch torchvision torchaudio' + - 'torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2' + - 'torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0' exclude: - - py-version: '3.11' - pytorch-version: 'torch==1.11.0+cpu torchvision==0.12.0+cpu torchaudio==0.11.0' - py-version: '3.11' pytorch-version: 'torch==1.12.1+cpu torchvision==0.13.1+cpu torchaudio==0.12.1' - py-version: '3.11' diff --git a/heat/core/arithmetics.py b/heat/core/arithmetics.py index a82afe8797..d91c6e6d1b 100644 --- a/heat/core/arithmetics.py +++ b/heat/core/arithmetics.py @@ -1954,14 +1954,7 @@ def left_shift( elif dtypes[dt] == types.bool: arrs[dt] = types.int(arrs[dt]) - try: - result = _operations.__binary_op(torch.bitwise_left_shift, t1, t2, out, where) - except AttributeError: # pragma: no cover - result = _operations.__binary_op( - torch.Tensor.__lshift__, t1, t2, out, where - ) # pytorch < 1.10 - - return result + return _operations.__binary_op(torch.bitwise_left_shift, t1, t2, out, where) def _lshift(self, other): @@ -2875,14 +2868,7 @@ def right_shift( elif dtypes[dt] == types.bool: arrs[dt] = types.int(arrs[dt]) - try: - result = _operations.__binary_op(torch.bitwise_right_shift, t1, t2, out, where) - except AttributeError: # pragma: no cover - result = _operations.__binary_op( - torch.Tensor.__rshift__, t1, t2, out, where - ) # pytorch < 1.10 - - return result + return _operations.__binary_op(torch.bitwise_right_shift, t1, t2, out, where) def _rshift(self, other): diff --git a/heat/core/io.py b/heat/core/io.py index 0d141b1c20..739e9ad0f9 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -134,7 +134,7 @@ def load_hdf5( data = handle[dataset] gshape = data.shape if split is not None: - gshape = np.array(gshape) + gshape = list(gshape) gshape[split] = int(gshape[split] * load_fraction) gshape = tuple(gshape) dims = len(gshape) diff --git a/heat/core/linalg/.DS_Store b/heat/core/linalg/.DS_Store deleted file mode 100644 index 4869d09464b09c335b19818e0984e357236c28bd..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKF-`+P3>-s{BBe=5xnJN1t0=sH2OMw#Q6L5BNI`cU-)8J}1koi$8j5Hv*>mgl z?CPdCp8?ovx48qB0OoW@e0!Lh@4HXzDl?82XY_bNk2k!an^Yf9ICsJ74j=d;_m|(j zZifDOyM1@!4L*73z&PyFyiq0vq<|EV0#ZNcZKx;(q`;v9*KN7<{(nzDG5;Tuw2}f+;Gz_;+4^C<;wxouoxGg)+CqP%?~S!l n&JeAb7_FEKZ^gHN>WV*ezb1|igU)==iTV+6U1U<=KNR=@zDpWe diff --git a/heat/core/stride_tricks.py b/heat/core/stride_tricks.py index 266a901044..22e9fff694 100644 --- a/heat/core/stride_tricks.py +++ b/heat/core/stride_tricks.py @@ -44,19 +44,6 @@ def broadcast_shape(shape_a: Tuple[int, ...], shape_b: Tuple[int, ...]) -> Tuple """ try: resulting_shape = torch.broadcast_shapes(shape_a, shape_b) - except AttributeError: # torch < 1.8 - it = itertools.zip_longest(shape_a[::-1], shape_b[::-1], fillvalue=1) - resulting_shape = max(len(shape_a), len(shape_b)) * [None] - for i, (a, b) in enumerate(it): - if a == 0 and b == 1 or b == 0 and a == 1: - resulting_shape[i] = 0 - elif a == 1 or b == 1 or a == b: - resulting_shape[i] = max(a, b) - else: - raise ValueError( - f"operands could not be broadcast, input shapes {shape_a} {shape_b}" - ) - return tuple(resulting_shape[::-1]) except TypeError: raise TypeError(f"operand 1 must be tuple of ints, not {type(shape_a)}") except NameError: diff --git a/heat/fft/fft.py b/heat/fft/fft.py index c989da7ab0..324b02955e 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -858,11 +858,6 @@ def ihfft2( ----- This function requires MPI communication if the input array is distributed and the split axis is transformed. """ - torch_has_ihfftn = hasattr(torch.fft, "ihfftn") - if not torch_has_ihfftn: # pragma: no cover - raise NotImplementedError( - f"n-dim inverse Hermitian FFTs not implemented for torch < 1.11.0. Your environment runs torch {torch.__version__}. Please upgrade torch." - ) return __real_fftn_op(x, torch.fft.ihfft2, s=s, axes=axes, norm=norm) @@ -896,11 +891,6 @@ def ihfftn( ----- This function requires MPI communication if the input array is distributed and the split axis is transformed. """ - torch_has_ihfftn = hasattr(torch.fft, "ihfftn") - if not torch_has_ihfftn: # pragma: no cover - raise NotImplementedError( - f"n-dim inverse Hermitian FFTs not implemented for torch < 1.11.0. Your environment runs torch {torch.__version__}. Please upgrade torch." - ) return __real_fftn_op(x, torch.fft.ihfftn, s=s, axes=axes, norm=norm) diff --git a/setup.py b/setup.py index cb7e4fa0cd..7680d3ea6b 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ install_requires=[ "mpi4py>=3.0.0", "numpy>=1.22.0", - "torch>=1.11.0, <2.2.3", + "torch>=1.12.0, <2.3.1", "scipy>=1.10.0", "pillow>=6.0.0", "torchvision>=0.12.0",