diff --git a/.github/pytorch-release-versions/pytorch-latest.txt b/.github/pytorch-release-versions/pytorch-latest.txt
index feaae22bac..b50dd27dd9 100644
--- a/.github/pytorch-release-versions/pytorch-latest.txt
+++ b/.github/pytorch-release-versions/pytorch-latest.txt
@@ -1 +1 @@
-1.13.0
+1.13.1
diff --git a/.github/release-drafter.yml b/.github/release-drafter.yml
index c1abd3124d..7fef410249 100644
--- a/.github/release-drafter.yml
+++ b/.github/release-drafter.yml
@@ -34,7 +34,7 @@ categories:
     label: 'chore'
   - title: '🧪 Testing'
     label: 'testing'
-
+
 change-template: '- #$NUMBER $TITLE (by @$AUTHOR)'
 categorie-template: '### $TITLE'
 exclude-labels:
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 822a501a9a..9be27312dd 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -1,5 +1,5 @@
 test:
-  image: nvidia/cuda:11.6.2-runtime-ubuntu20.04
+  image: nvidia/cuda:11.7.1-runtime-ubuntu22.04
   tags:
     - cuda
     - x86_64
@@ -9,7 +9,7 @@ test:
     - DEBIAN_FRONTEND=noninteractive apt -y install libopenmpi-dev openmpi-bin openmpi-doc
     - apt -y install libhdf5-openmpi-dev libpnetcdf-dev
     - pip install pytest coverage
-    - pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
+    - pip3 install torch torchvision torchaudio
    - pip install .[hdf5,netcdf]
     - COVERAGE_FILE=report/cov/coverage1 HEAT_TEST_USE_DEVICE=cpu mpirun --allow-run-as-root -n 1 coverage run --source=heat --parallel-mode -m pytest --junitxml=report/test/report1.xml heat/
     - COVERAGE_FILE=report/cov/coverage2 HEAT_TEST_USE_DEVICE=gpu mpirun --allow-run-as-root -n 3 coverage run --source=heat --parallel-mode -m pytest --junitxml=report/test/report3.xml heat/
diff --git a/heat/core/communication.py b/heat/core/communication.py
index ad58dae964..9aa71323da 100644
--- a/heat/core/communication.py
+++ b/heat/core/communication.py
@@ -240,7 +240,11 @@ def counts_displs_shape(
 
     @classmethod
     def mpi_type_and_elements_of(
-        cls, obj: Union[DNDarray, torch.Tensor], counts: Tuple[int], displs: Tuple[int]
+        cls,
+        obj: Union[DNDarray, torch.Tensor],
+        counts: Tuple[int],
+        displs: Tuple[int],
+        is_contiguous: Optional[bool],
     ) -> Tuple[MPI.Datatype, Tuple[int, ...]]:
         """
         Determines the MPI data type and number of respective elements for the given tensor (:class:`~heat.core.dndarray.DNDarray`
@@ -255,12 +259,18 @@
             Optional counts arguments for variable MPI-calls (e.g. Alltoallv)
         displs : Tuple[ints,...], optional
             Optional displacements arguments for variable MPI-calls (e.g. Alltoallv)
+        is_contiguous: bool
+            Information on global contiguity of the memory-distributed object. If `None`, it will be set to local contiguity via ``torch.Tensor.is_contiguous()``.
         # ToDo: The option to explicitely specify the counts and displacements to be send still needs propper implementation
         """
         mpi_type, elements = cls.__mpi_type_mappings[obj.dtype], torch.numel(obj)
 
-        # simple case, continuous memory can be transmitted as is
-        if obj.is_contiguous():
+        # simple case, contiguous memory can be transmitted as is
+        if is_contiguous is None:
+            # determine local contiguity
+            is_contiguous = obj.is_contiguous()
+
+        if is_contiguous:
             if counts is None:
                 return mpi_type, elements
             else:
@@ -273,7 +283,7 @@
                     ),
                 )
 
-        # non-continuous memory, e.g. after a transpose, has to be packed in derived MPI types
+        # non-contiguous memory, e.g. after a transpose, has to be packed in derived MPI types
         elements = obj.shape[0]
         shape = obj.shape[1:]
         strides = [1] * len(shape)
@@ -305,7 +315,11 @@ def as_mpi_memory(cls, obj) -> MPI.memory:
 
     @classmethod
     def as_buffer(
-        cls, obj: torch.Tensor, counts: Tuple[int] = None, displs: Tuple[int] = None
+        cls,
+        obj: torch.Tensor,
+        counts: Tuple[int] = None,
+        displs: Tuple[int] = None,
+        is_contiguous: Optional[bool] = None,
     ) -> List[Union[MPI.memory, Tuple[int, int], MPI.Datatype]]:
         """
         Converts a passed ``torch.Tensor`` into a memory buffer object with associated number of elements and MPI data type.
@@ -318,14 +332,16 @@
             Optional counts arguments for variable MPI-calls (e.g. Alltoallv)
         displs : Tuple[int,...], optional
             Optional displacements arguments for variable MPI-calls (e.g. Alltoallv)
+        is_contiguous: bool, optional
+            Optional information on global contiguity of the memory-distributed object.
         """
         squ = False
         if not obj.is_contiguous() and obj.ndim == 1:
             # this makes the math work below this function.
             obj.unsqueeze_(-1)
             squ = True
 
-        mpi_type, elements = cls.mpi_type_and_elements_of(obj, counts, displs)
+        mpi_type, elements = cls.mpi_type_and_elements_of(obj, counts, displs, is_contiguous)
         mpi_mem = cls.as_mpi_memory(obj)
         if squ:
             # the squeeze happens in the mpi_type_and_elements_of function in the case of a
@@ -1037,7 +1053,6 @@ def __allgather_like(
                     type(sendbuf)
                 )
             )
-
         # unpack the receive buffer
         if isinstance(recvbuf, tuple):
             recvbuf, recv_counts, recv_displs = recvbuf
@@ -1053,17 +1068,18 @@ def __allgather_like(
 
         # keep a reference to the original buffer object
         original_recvbuf = recvbuf
-
+        sbuf_is_contiguous, rbuf_is_contiguous = None, None
         # permute the send_axis order so that the split send_axis is the first to be transmitted
         if axis != 0:
             send_axis_permutation = list(range(sendbuf.ndimension()))
             send_axis_permutation[0], send_axis_permutation[axis] = axis, 0
             sendbuf = sendbuf.permute(*send_axis_permutation)
+            sbuf_is_contiguous = False
 
-        if axis != 0:
             recv_axis_permutation = list(range(recvbuf.ndimension()))
             recv_axis_permutation[0], recv_axis_permutation[axis] = axis, 0
             recvbuf = recvbuf.permute(*recv_axis_permutation)
+            rbuf_is_contiguous = False
         else:
             recv_axis_permutation = None
 
@@ -1074,20 +1090,18 @@ def __allgather_like(
         if sendbuf is MPI.IN_PLACE or not isinstance(sendbuf, torch.Tensor):
             mpi_sendbuf = sbuf
         else:
-            mpi_sendbuf = self.as_buffer(sbuf, send_counts, send_displs)
+            mpi_sendbuf = self.as_buffer(sbuf, send_counts, send_displs, sbuf_is_contiguous)
             if send_counts is not None:
                 mpi_sendbuf[1] = mpi_sendbuf[1][0][self.rank]
 
         if recvbuf is MPI.IN_PLACE or not isinstance(recvbuf, torch.Tensor):
             mpi_recvbuf = rbuf
         else:
-            mpi_recvbuf = self.as_buffer(rbuf, recv_counts, recv_displs)
+            mpi_recvbuf = self.as_buffer(rbuf, recv_counts, recv_displs, rbuf_is_contiguous)
             if recv_counts is None:
                 mpi_recvbuf[1] //= self.size
 
-        # perform the scatter operation
         exit_code = func(mpi_sendbuf, mpi_recvbuf, **kwargs)
-
         return exit_code, sbuf, rbuf, original_recvbuf, recv_axis_permutation
 
     def Allgather(
@@ -1260,7 +1274,7 @@ def __alltoall_like(
         # keep a reference to the original buffer object
         original_recvbuf = recvbuf
 
-        # Simple case, continuous buffers can be transmitted as is
+        # Simple case, contiguous buffers can be transmitted as is
         if send_axis < 2 and recv_axis < 2:
             send_axis_permutation = list(range(recvbuf.ndimension()))
             recv_axis_permutation = list(range(recvbuf.ndimension()))
diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index 9ec0ea89e1..6e9d2c56ef 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -1268,8 +1268,11 @@ def resplit_(self, axis: int = None):
         axis = sanitize_axis(self.shape, axis)
 
         # early out for unchanged content
+        if self.comm.size == 1:
+            self.__split = axis
         if axis == self.split:
             return self
+
         if axis is None:
             gathered = torch.empty(
                 self.shape, dtype=self.dtype.torch_type(), device=self.device.torch_device
diff --git a/heat/core/linalg/basics.py b/heat/core/linalg/basics.py
index bc5d3e9e65..7a2776386b 100644
--- a/heat/core/linalg/basics.py
+++ b/heat/core/linalg/basics.py
@@ -510,6 +510,13 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
     if b.dtype != c_type:
         b = c_type(b, device=b.device)
 
+    # early out for single-process setup, torch matmul
+    if a.comm.size == 1:
+        ret = factories.array(torch.matmul(a.larray, b.larray), device=a.device)
+        if gpu_int_flag:
+            ret = og_type(ret, device=a.device)
+        return ret
+
     if a.split is None and b.split is None:  # matmul from torch
         if len(a.gshape) < 2 or len(b.gshape) < 2 or not allow_resplit:
             # if either of A or B is a vector
@@ -517,17 +524,17 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
             if gpu_int_flag:
                 ret = og_type(ret, device=a.device)
             return ret
-        else:
-            a.resplit_(0)
-            slice_0 = a.comm.chunk(a.shape, a.split)[2][0]
-            hold = a.larray @ b.larray
-            c = factories.zeros((a.gshape[-2], b.gshape[1]), dtype=c_type, device=a.device)
-            c.larray[slice_0.start : slice_0.stop, :] += hold
-            c.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
-            if gpu_int_flag:
-                c = og_type(c, device=a.device)
-            return c
+        a.resplit_(0)
+        slice_0 = a.comm.chunk(a.shape, a.split)[2][0]
+        hold = a.larray @ b.larray
+
+        c = factories.zeros((a.gshape[-2], b.gshape[1]), dtype=c_type, device=a.device)
+        c.larray[slice_0.start : slice_0.stop, :] += hold
+        c.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
+        if gpu_int_flag:
+            c = og_type(c, device=a.device)
+        return c
 
     # if they are vectors they need to be expanded to be the proper dimensions
     vector_flag = False  # flag to run squeeze at the end of the function
diff --git a/heat/core/linalg/tests/test_basics.py b/heat/core/linalg/tests/test_basics.py
index a3cb827b84..08e0ac43dd 100644
--- a/heat/core/linalg/tests/test_basics.py
+++ b/heat/core/linalg/tests/test_basics.py
@@ -237,7 +237,6 @@ def test_inv(self):
         self.assertTupleEqual(ainv.shape, a.shape)
         self.assertTrue(ht.allclose(ainv, ares, atol=1e-6))
 
-        # distributed
         a = ht.array([[5.0, -3, 2], [-3, 2, -1], [-3, 2, -2]], split=0)
         ainv = ht.linalg.inv(a)
         self.assertEqual(ainv.split, a.split)
@@ -245,6 +244,7 @@ def test_inv(self):
         self.assertTupleEqual(ainv.shape, a.shape)
         self.assertTrue(ht.allclose(ainv, ares, atol=1e-6))
 
+        ares = ht.array([[2.0, 2, 1], [3, 4, 1], [0, 1, -1]], split=1)
         a = ht.array([[5.0, -3, 2], [-3, 2, -1], [-3, 2, -2]], split=1)
         ainv = ht.linalg.inv(a)
         self.assertEqual(ainv.split, a.split)
@@ -281,7 +281,7 @@ def test_inv(self):
         self.assertTrue(ht.allclose(ainv, ares, atol=1e-6))
 
         # pivoting row change
-        ares = ht.array([[-1, 0, 2], [2, 0, -1], [-6, 3, 0]], dtype=ht.double) / 3.0
+        ares = ht.array([[-1, 0, 2], [2, 0, -1], [-6, 3, 0]], dtype=ht.double, split=0) / 3.0
         a = ht.array([[1, 2, 0], [2, 4, 1], [2, 1, 0]], dtype=ht.double, split=0)
         ainv = ht.linalg.inv(a)
         self.assertEqual(ainv.split, a.split)
@@ -289,6 +289,7 @@ def test_inv(self):
         self.assertTupleEqual(ainv.shape, a.shape)
         self.assertTrue(ht.allclose(ainv, ares, atol=1e-6))
 
+        ares = ht.array([[-1, 0, 2], [2, 0, -1], [-6, 3, 0]], dtype=ht.double, split=1) / 3.0
         a = ht.array([[1, 2, 0], [2, 4, 1], [2, 1, 0]], dtype=ht.double, split=1)
         ainv = ht.linalg.inv(a)
         self.assertEqual(ainv.split, a.split)
@@ -365,9 +366,28 @@ def test_matmul(self):
         self.assertEqual(ret00.shape, (n, k))
         self.assertEqual(ret00.dtype, ht.float)
         self.assertEqual(ret00.split, None)
-        self.assertEqual(a.split, 0)
+        if a.comm.size > 1:
+            self.assertEqual(a.split, 0)
         self.assertEqual(b.split, None)
 
+        # splits 0 None on 1 process
+        if a.comm.size == 1:
+            a = ht.ones((n, m), split=0)
+            b = ht.ones((j, k), split=None)
+            a[0] = ht.arange(1, m + 1)
+            a[:, -1] = ht.arange(1, n + 1)
+            b[0] = ht.arange(1, k + 1)
+            b[:, 0] = ht.arange(1, j + 1)
+            ret00 = ht.matmul(a, b, allow_resplit=True)
+
+            self.assertEqual(ht.all(ret00 == ht.array(a_torch @ b_torch)), 1)
+            self.assertIsInstance(ret00, ht.DNDarray)
+            self.assertEqual(ret00.shape, (n, k))
+            self.assertEqual(ret00.dtype, ht.float)
+            self.assertEqual(ret00.split, None)
+            self.assertEqual(a.split, 0)
+            self.assertEqual(b.split, None)
 
+        if a.comm.size > 1:
             # splits 00
             a = ht.ones((n, m), split=0, dtype=ht.float64)
diff --git a/heat/core/logical.py b/heat/core/logical.py
index a6be081ea7..8106a556ee 100644
--- a/heat/core/logical.py
+++ b/heat/core/logical.py
@@ -140,7 +140,19 @@ def allclose(
     t1, t2 = __sanitize_close_input(x, y)
 
     # no sanitation for shapes of x and y needed, torch.allclose raises relevant errors
-    _local_allclose = torch.tensor(torch.allclose(t1.larray, t2.larray, rtol, atol, equal_nan))
+    try:
+        _local_allclose = torch.tensor(torch.allclose(t1.larray, t2.larray, rtol, atol, equal_nan))
+    except RuntimeError:
+        promoted_dtype = torch.promote_types(t1.larray.dtype, t2.larray.dtype)
+        _local_allclose = torch.tensor(
+            torch.allclose(
+                t1.larray.type(promoted_dtype),
+                t2.larray.type(promoted_dtype),
+                rtol,
+                atol,
+                equal_nan,
+            )
+        )
 
     # If x is distributed, then y is also distributed along the same axis
     if t1.comm.is_distributed():
diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py
index 33ebf4d365..7cf02ab016 100644
--- a/heat/core/manipulations.py
+++ b/heat/core/manipulations.py
@@ -3372,6 +3372,9 @@ def resplit(arr: DNDarray, axis: int = None) -> DNDarray:
     # early out for unchanged content
     if axis == arr.split:
         return arr.copy()
+    if not arr.is_distributed():
+        return factories.array(arr.larray, split=axis, device=arr.device, copy=True)
+
     if axis is None:
         # new_arr = arr.copy()
         gathered = torch.empty(
@@ -3381,11 +3384,6 @@
         arr.comm.Allgatherv(arr.larray, (gathered, counts, displs), recv_axis=arr.split)
         new_arr = factories.array(gathered, is_split=axis, device=arr.device, dtype=arr.dtype)
         return new_arr
-    # tensor needs be split/sliced locally
-    if arr.split is None:
-        temp = arr.larray[arr.comm.chunk(arr.shape, axis)[2]]
-        new_arr = factories.array(temp, is_split=axis, device=arr.device, dtype=arr.dtype)
-        return new_arr
 
     arr_tiles = tiling.SplitTiles(arr)
     new_arr = factories.empty(arr.gshape, split=axis, dtype=arr.dtype, device=arr.device)
diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py
index e42c5a9a14..726a85e77a 100644
--- a/heat/core/tests/test_dndarray.py
+++ b/heat/core/tests/test_dndarray.py
@@ -126,7 +126,6 @@ def test_gethalo(self):
         # test no data on process
         data_np = np.arange(2 * 12).reshape(2, 12)
         data = ht.array(data_np, split=0)
-        print("DEBUGGING: data.lshape_map = ", data.lshape_map)
         data.get_halo(1)
         data_with_halos = data.array_with_halos
 
diff --git a/heat/core/tests/test_logical.py b/heat/core/tests/test_logical.py
index 691df7ec62..c2e3d1a786 100644
--- a/heat/core/tests/test_logical.py
+++ b/heat/core/tests/test_logical.py
@@ -182,6 +182,7 @@ def test_allclose(self):
         c = ht.zeros((4, 6), split=0)
         d = ht.zeros((4, 6), split=1)
         e = ht.zeros((4, 6))
+        f = ht.float64([[2.000005, 2.000005], [2.000005, 2.000005]])
 
         self.assertFalse(ht.allclose(a, b))
         self.assertTrue(ht.allclose(a, b, atol=1e-04))
@@ -189,6 +190,7 @@ def test_allclose(self):
         self.assertTrue(ht.allclose(a, 2))
         self.assertTrue(ht.allclose(a, 2.0))
         self.assertTrue(ht.allclose(2, a))
+        self.assertTrue(ht.allclose(f, a))
         self.assertTrue(ht.allclose(c, d))
         self.assertTrue(ht.allclose(c, e))
         self.assertTrue(e.allclose(c))
diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py
index 9a41bceab8..4464053fd3 100644
--- a/heat/core/tests/test_manipulations.py
+++ b/heat/core/tests/test_manipulations.py
@@ -2992,6 +2992,16 @@ def test_resplit(self):
         self.assertEqual(data2.lshape, (data.comm.size, 1))
         self.assertEqual(data2.split, 1)
 
+        # resplitting a non-distributed DNDarray with split not None
+        if ht.MPI_WORLD.size == 1:
+            data = ht.zeros(10, 10, split=0)
+            data2 = ht.resplit(data, 1)
+            data3 = ht.resplit(data, None)
+            self.assertTrue((data == data2).all())
+            self.assertTrue((data == data3).all())
+            self.assertEqual(data2.split, 1)
+            self.assertTrue(data3.split is None)
+
         # splitting an unsplit tensor should result in slicing the tensor locally
         shape = (ht.MPI_WORLD.size, ht.MPI_WORLD.size)
         data = ht.zeros(shape)
diff --git a/heat/core/tests/test_suites/basic_test.py b/heat/core/tests/test_suites/basic_test.py
index f094668bc8..39f6a5f063 100644
--- a/heat/core/tests/test_suites/basic_test.py
+++ b/heat/core/tests/test_suites/basic_test.py
@@ -136,8 +136,12 @@ def assert_array_equal(self, heat_array, expected_array):
             "Local shapes do not match. "
             "Got {} expected {}".format(heat_array.lshape, expected_array[slices].shape),
         )
-        local_heat_numpy = heat_array.numpy()
-        self.assertTrue(np.allclose(local_heat_numpy, expected_array))
+        # compare local tensors to corresponding slice of expected_array
+        is_allclose = np.allclose(heat_array.larray.cpu(), expected_array[slices])
+        ht_is_allclose = ht.array(
+            [is_allclose], dtype=ht.bool, is_split=0, device=heat_array.device
+        )
+        self.assertTrue(ht.all(ht_is_allclose))
 
     def assert_func_equal(
         self,
diff --git a/setup.py b/setup.py
index 2210ceaf97..0e8f00b0de 100644
--- a/setup.py
+++ b/setup.py
@@ -33,7 +33,7 @@
     install_requires=[
         "mpi4py>=3.0.0",
         "numpy>=1.13.0",
-        "torch>=1.7.0, <1.13.1",
+        "torch>=1.7.0, <1.13.2",
         "scipy>=0.14.0",
         "pillow>=6.0.0",
         "torchvision>=0.8.0",
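
Editorial note (not part of the patch): the heat/core/logical.py hunk above makes ht.allclose retry with torch.promote_types when torch.allclose rejects operands of different floating-point dtypes. Below is a minimal stand-alone sketch of that fallback in plain PyTorch; the tensor names and values are illustrative only and mirror the float64-vs-float32 case added to test_allclose.

import torch

t1 = torch.full((2, 2), 2.000005, dtype=torch.float64)
t2 = torch.full((2, 2), 2.0, dtype=torch.float32)

try:
    # may raise RuntimeError on mixed dtypes, depending on the torch version
    result = torch.allclose(t1, t2, 1e-05, 1e-08, False)
except RuntimeError:
    # cast both operands to the common dtype and compare again
    promoted = torch.promote_types(t1.dtype, t2.dtype)  # torch.float64 here
    result = torch.allclose(t1.type(promoted), t2.type(promoted), 1e-05, 1e-08, False)

print(result)  # True: |2.000005 - 2.0| <= atol + rtol * |2.0|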
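
Editorial note (not part of the patch): the dndarray.py, manipulations.py, and linalg/basics.py hunks add single-process early outs. A short usage sketch follows, assuming the patched Heat from this diff is installed; the shapes are illustrative, and the script also runs unchanged on more than one MPI process, just without taking the shortcuts.

import heat as ht

a = ht.ones((4, 5), split=0)   # split along rows; on one process all data is local
b = ht.ones((5, 3), split=None)

# on a single process ht.matmul now short-circuits to a local torch.matmul
c = ht.matmul(a, b, allow_resplit=True)
print(c.shape, c.split)        # (4, 3) None

# resplitting a non-distributed array just re-wraps the local tensor with the new split
a2 = ht.resplit(a, 1)
print(a2.split)                # 1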