Fix edge-case contiguity mismatch for Allgatherv (#1058)
* Fix edge-case contiguity mismatch for Allgatherv

* Update ubuntu

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* switch back to ubuntu 20.04

* Upgrade CI to ubuntu 22.04 and cuda 11.7.1

* avoid unnecessary gathering of test DNDarrays

* early out for resplit of non-distributed DNDarrays

* match split of comparison array to expected output

* avoid MPI calls in non-distributed cases

* avoid MPI calls in non-distributed resplit

* set is_contiguous default to None

* remove print statement

* upgrade torch version

* copy to cpu before comparing

* use ht.allclose instead of np.allclose

* cast different dtype operands to promoted dtype within torch call

* compare local tensors to corresponding slice of expected_array only

* expand tests

* remove redundant code

* use pytorch with cuda117 support

* [skip ci] Update heat/core/communication.py

Co-authored-by: mtar <m.tarnawa@fz-juelich.de>

* [skip ci] Update heat/core/communication.py

Co-authored-by: mtar <m.tarnawa@fz-juelich.de>

* [skip ci] Update heat/core/communication.py

Co-authored-by: mtar <m.tarnawa@fz-juelich.de>

* Remove dead code

* Update pytorch-latest.txt

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: mtar <m.tarnawa@fz-juelich.de>
3 people authored Jan 19, 2023
1 parent 54db506 commit 73e6204
Showing 12 changed files with 108 additions and 39 deletions.
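The following editorial sketch (not part of the commit) shows the code path this fix hardens: resplitting an array that is distributed along a non-zero axis forces an Allgatherv over permuted local buffers, which is where the contiguity mismatch could occur. The shapes and dtype are assumptions chosen for illustration.

import heat as ht

x = ht.arange(4 * 5, dtype=ht.float32).reshape((4, 5)).resplit_(1)
y = ht.resplit(x, None)   # Allgatherv along axis 1 permutes the local buffers
print(ht.allclose(x, y))  # expected: True on any number of processes
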
2 changes: 1 addition & 1 deletion .github/release-drafter.yml
@@ -34,7 +34,7 @@ categories:
label: 'chore'
- title: '🧪 Testing'
label: 'testing'

change-template: '- #$NUMBER $TITLE (by @$AUTHOR)'
categorie-template: '### $TITLE'
exclude-labels:
4 changes: 2 additions & 2 deletions .gitlab-ci.yml
@@ -1,5 +1,5 @@
test:
image: nvidia/cuda:11.6.2-runtime-ubuntu20.04
image: nvidia/cuda:11.7.1-runtime-ubuntu22.04
tags:
- cuda
- x86_64
@@ -9,7 +9,7 @@ test:
- DEBIAN_FRONTEND=noninteractive apt -y install libopenmpi-dev openmpi-bin openmpi-doc
- apt -y install libhdf5-openmpi-dev libpnetcdf-dev
- pip install pytest coverage
- pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
- pip3 install torch torchvision torchaudio
- pip install .[hdf5,netcdf]
- COVERAGE_FILE=report/cov/coverage1 HEAT_TEST_USE_DEVICE=cpu mpirun --allow-run-as-root -n 1 coverage run --source=heat --parallel-mode -m pytest --junitxml=report/test/report1.xml heat/
- COVERAGE_FILE=report/cov/coverage2 HEAT_TEST_USE_DEVICE=gpu mpirun --allow-run-as-root -n 3 coverage run --source=heat --parallel-mode -m pytest --junitxml=report/test/report3.xml heat/
42 changes: 28 additions & 14 deletions heat/core/communication.py
@@ -240,7 +240,11 @@ def counts_displs_shape(

@classmethod
def mpi_type_and_elements_of(
cls, obj: Union[DNDarray, torch.Tensor], counts: Tuple[int], displs: Tuple[int]
cls,
obj: Union[DNDarray, torch.Tensor],
counts: Tuple[int],
displs: Tuple[int],
is_contiguous: Optional[bool],
) -> Tuple[MPI.Datatype, Tuple[int, ...]]:
"""
Determines the MPI data type and number of respective elements for the given tensor (:class:`~heat.core.dndarray.DNDarray`
@@ -255,12 +259,18 @@ def mpi_type_and_elements_of(
Optional counts arguments for variable MPI-calls (e.g. Alltoallv)
displs : Tuple[ints,...], optional
Optional displacements arguments for variable MPI-calls (e.g. Alltoallv)
is_contiguous: bool
Information on global contiguity of the memory-distributed object. If `None`, it will be set to local contiguity via ``torch.Tensor.is_contiguous()``.
# ToDo: The option to explicitly specify the counts and displacements to be sent still needs proper implementation
"""
mpi_type, elements = cls.__mpi_type_mappings[obj.dtype], torch.numel(obj)

# simple case, continuous memory can be transmitted as is
if obj.is_contiguous():
# simple case, contiguous memory can be transmitted as is
if is_contiguous is None:
# determine local contiguity
is_contiguous = obj.is_contiguous()

if is_contiguous:
if counts is None:
return mpi_type, elements
else:
@@ -273,7 +283,7 @@ def mpi_type_and_elements_of(
),
)

# non-continuous memory, e.g. after a transpose, has to be packed in derived MPI types
# non-contiguous memory, e.g. after a transpose, has to be packed in derived MPI types
elements = obj.shape[0]
shape = obj.shape[1:]
strides = [1] * len(shape)
@@ -305,7 +315,11 @@ def as_mpi_memory(cls, obj) -> MPI.memory:

@classmethod
def as_buffer(
cls, obj: torch.Tensor, counts: Tuple[int] = None, displs: Tuple[int] = None
cls,
obj: torch.Tensor,
counts: Tuple[int] = None,
displs: Tuple[int] = None,
is_contiguous: Optional[bool] = None,
) -> List[Union[MPI.memory, Tuple[int, int], MPI.Datatype]]:
"""
Converts a passed ``torch.Tensor`` into a memory buffer object with associated number of elements and MPI data type.
@@ -318,14 +332,16 @@ def as_buffer(
Optional counts arguments for variable MPI-calls (e.g. Alltoallv)
displs : Tuple[int,...], optional
Optional displacements arguments for variable MPI-calls (e.g. Alltoallv)
is_contiguous: bool, optional
Optional information on global contiguity of the memory-distributed object.
"""
squ = False
if not obj.is_contiguous() and obj.ndim == 1:
# this makes the math work below this function.
obj.unsqueeze_(-1)
squ = True
mpi_type, elements = cls.mpi_type_and_elements_of(obj, counts, displs)

mpi_type, elements = cls.mpi_type_and_elements_of(obj, counts, displs, is_contiguous)
mpi_mem = cls.as_mpi_memory(obj)
if squ:
# the squeeze happens in the mpi_type_and_elements_of function in the case of a
@@ -1037,7 +1053,6 @@ def __allgather_like(
type(sendbuf)
)
)

# unpack the receive buffer
if isinstance(recvbuf, tuple):
recvbuf, recv_counts, recv_displs = recvbuf
@@ -1053,17 +1068,18 @@

# keep a reference to the original buffer object
original_recvbuf = recvbuf

sbuf_is_contiguous, rbuf_is_contiguous = None, None
# permute the send_axis order so that the split send_axis is the first to be transmitted
if axis != 0:
send_axis_permutation = list(range(sendbuf.ndimension()))
send_axis_permutation[0], send_axis_permutation[axis] = axis, 0
sendbuf = sendbuf.permute(*send_axis_permutation)
sbuf_is_contiguous = False

if axis != 0:
recv_axis_permutation = list(range(recvbuf.ndimension()))
recv_axis_permutation[0], recv_axis_permutation[axis] = axis, 0
recvbuf = recvbuf.permute(*recv_axis_permutation)
rbuf_is_contiguous = False
else:
recv_axis_permutation = None

@@ -1074,20 +1090,18 @@
if sendbuf is MPI.IN_PLACE or not isinstance(sendbuf, torch.Tensor):
mpi_sendbuf = sbuf
else:
mpi_sendbuf = self.as_buffer(sbuf, send_counts, send_displs)
mpi_sendbuf = self.as_buffer(sbuf, send_counts, send_displs, sbuf_is_contiguous)
if send_counts is not None:
mpi_sendbuf[1] = mpi_sendbuf[1][0][self.rank]

if recvbuf is MPI.IN_PLACE or not isinstance(recvbuf, torch.Tensor):
mpi_recvbuf = rbuf
else:
mpi_recvbuf = self.as_buffer(rbuf, recv_counts, recv_displs)
mpi_recvbuf = self.as_buffer(rbuf, recv_counts, recv_displs, rbuf_is_contiguous)
if recv_counts is None:
mpi_recvbuf[1] //= self.size

# perform the scatter operation
exit_code = func(mpi_sendbuf, mpi_recvbuf, **kwargs)

return exit_code, sbuf, rbuf, original_recvbuf, recv_axis_permutation

def Allgather(
@@ -1260,7 +1274,7 @@ def __alltoall_like(
# keep a reference to the original buffer object
original_recvbuf = recvbuf

# Simple case, continuous buffers can be transmitted as is
# Simple case, contiguous buffers can be transmitted as is
if send_axis < 2 and recv_axis < 2:
send_axis_permutation = list(range(recvbuf.ndimension()))
recv_axis_permutation = list(range(recvbuf.ndimension()))
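A torch-only sketch of why a purely local contiguity check can disagree across ranks (the slab shapes are assumptions): after the axis permutation done in __allgather_like, a one-row slab still reports is_contiguous() as True while a multi-row slab does not, so the permuted branches now pass is_contiguous=False to as_buffer on every rank instead of letting each rank decide locally.

import torch

full = torch.arange(12.0).reshape(3, 4)
rank0_slab = full[0:1].permute(1, 0)  # shape (4, 1): still contiguous after permute
rank1_slab = full[1:3].permute(1, 0)  # shape (4, 2): no longer contiguous
print(rank0_slab.is_contiguous(), rank1_slab.is_contiguous())  # True False
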
3 changes: 3 additions & 0 deletions heat/core/dndarray.py
@@ -1268,8 +1268,11 @@ def resplit_(self, axis: int = None):
axis = sanitize_axis(self.shape, axis)

# early out for unchanged content
if self.comm.size == 1:
self.__split = axis
if axis == self.split:
return self

if axis is None:
gathered = torch.empty(
self.shape, dtype=self.dtype.torch_type(), device=self.device.torch_device
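A minimal sketch of the new early-out in resplit_, assuming a single-process run:

import heat as ht

a = ht.zeros((4, 4), split=0)
if a.comm.size == 1:
    a.resplit_(1)   # no MPI traffic, only the split attribute is relabelled
    print(a.split)  # 1
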
27 changes: 17 additions & 10 deletions heat/core/linalg/basics.py
@@ -510,24 +510,31 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
if b.dtype != c_type:
b = c_type(b, device=b.device)

# early out for single-process setup, torch matmul
if a.comm.size == 1:
ret = factories.array(torch.matmul(a.larray, b.larray), device=a.device)
if gpu_int_flag:
ret = og_type(ret, device=a.device)
return ret

if a.split is None and b.split is None: # matmul from torch
if len(a.gshape) < 2 or len(b.gshape) < 2 or not allow_resplit:
# if either of A or B is a vector
ret = factories.array(torch.matmul(a.larray, b.larray), device=a.device)
if gpu_int_flag:
ret = og_type(ret, device=a.device)
return ret
else:
a.resplit_(0)
slice_0 = a.comm.chunk(a.shape, a.split)[2][0]
hold = a.larray @ b.larray

c = factories.zeros((a.gshape[-2], b.gshape[1]), dtype=c_type, device=a.device)
c.larray[slice_0.start : slice_0.stop, :] += hold
c.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
if gpu_int_flag:
c = og_type(c, device=a.device)
return c
a.resplit_(0)
slice_0 = a.comm.chunk(a.shape, a.split)[2][0]
hold = a.larray @ b.larray

c = factories.zeros((a.gshape[-2], b.gshape[1]), dtype=c_type, device=a.device)
c.larray[slice_0.start : slice_0.stop, :] += hold
c.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
if gpu_int_flag:
c = og_type(c, device=a.device)
return c

# if they are vectors they need to be expanded to be the proper dimensions
vector_flag = False # flag to run squeeze at the end of the function
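A sketch of the single-process fast path added to matmul (the shapes are assumptions): with one process the result comes straight from torch.matmul on the local tensors.

import heat as ht

a = ht.ones((4, 5), split=0)
b = ht.ones((5, 3))
if a.comm.size == 1:
    c = ht.matmul(a, b, allow_resplit=True)
    print(c.shape, c.split)  # (4, 3) None
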
26 changes: 23 additions & 3 deletions heat/core/linalg/tests/test_basics.py
@@ -237,14 +237,14 @@ def test_inv(self):
self.assertTupleEqual(ainv.shape, a.shape)
self.assertTrue(ht.allclose(ainv, ares, atol=1e-6))

# distributed
a = ht.array([[5.0, -3, 2], [-3, 2, -1], [-3, 2, -2]], split=0)
ainv = ht.linalg.inv(a)
self.assertEqual(ainv.split, a.split)
self.assertEqual(ainv.device, a.device)
self.assertTupleEqual(ainv.shape, a.shape)
self.assertTrue(ht.allclose(ainv, ares, atol=1e-6))

ares = ht.array([[2.0, 2, 1], [3, 4, 1], [0, 1, -1]], split=1)
a = ht.array([[5.0, -3, 2], [-3, 2, -1], [-3, 2, -2]], split=1)
ainv = ht.linalg.inv(a)
self.assertEqual(ainv.split, a.split)
@@ -281,14 +281,15 @@ def test_inv(self):
self.assertTrue(ht.allclose(ainv, ares, atol=1e-6))

# pivoting row change
ares = ht.array([[-1, 0, 2], [2, 0, -1], [-6, 3, 0]], dtype=ht.double) / 3.0
ares = ht.array([[-1, 0, 2], [2, 0, -1], [-6, 3, 0]], dtype=ht.double, split=0) / 3.0
a = ht.array([[1, 2, 0], [2, 4, 1], [2, 1, 0]], dtype=ht.double, split=0)
ainv = ht.linalg.inv(a)
self.assertEqual(ainv.split, a.split)
self.assertEqual(ainv.device, a.device)
self.assertTupleEqual(ainv.shape, a.shape)
self.assertTrue(ht.allclose(ainv, ares, atol=1e-6))

ares = ht.array([[-1, 0, 2], [2, 0, -1], [-6, 3, 0]], dtype=ht.double, split=1) / 3.0
a = ht.array([[1, 2, 0], [2, 4, 1], [2, 1, 0]], dtype=ht.double, split=1)
ainv = ht.linalg.inv(a)
self.assertEqual(ainv.split, a.split)
@@ -365,9 +366,28 @@ def test_matmul(self):
self.assertEqual(ret00.shape, (n, k))
self.assertEqual(ret00.dtype, ht.float)
self.assertEqual(ret00.split, None)
self.assertEqual(a.split, 0)
if a.comm.size > 1:
self.assertEqual(a.split, 0)
self.assertEqual(b.split, None)

# splits 0 None on 1 process
if a.comm.size == 1:
a = ht.ones((n, m), split=0)
b = ht.ones((j, k), split=None)
a[0] = ht.arange(1, m + 1)
a[:, -1] = ht.arange(1, n + 1)
b[0] = ht.arange(1, k + 1)
b[:, 0] = ht.arange(1, j + 1)
ret00 = ht.matmul(a, b, allow_resplit=True)

self.assertEqual(ht.all(ret00 == ht.array(a_torch @ b_torch)), 1)
self.assertIsInstance(ret00, ht.DNDarray)
self.assertEqual(ret00.shape, (n, k))
self.assertEqual(ret00.dtype, ht.float)
self.assertEqual(ret00.split, None)
self.assertEqual(a.split, 0)
self.assertEqual(b.split, None)

if a.comm.size > 1:
# splits 00
a = ht.ones((n, m), split=0, dtype=ht.float64)
14 changes: 13 additions & 1 deletion heat/core/logical.py
@@ -140,7 +140,19 @@ def allclose(
t1, t2 = __sanitize_close_input(x, y)

# no sanitation for shapes of x and y needed, torch.allclose raises relevant errors
_local_allclose = torch.tensor(torch.allclose(t1.larray, t2.larray, rtol, atol, equal_nan))
try:
_local_allclose = torch.tensor(torch.allclose(t1.larray, t2.larray, rtol, atol, equal_nan))
except RuntimeError:
promoted_dtype = torch.promote_types(t1.larray.dtype, t2.larray.dtype)
_local_allclose = torch.tensor(
torch.allclose(
t1.larray.type(promoted_dtype),
t2.larray.type(promoted_dtype),
rtol,
atol,
equal_nan,
)
)

# If x is distributed, then y is also distributed along the same axis
if t1.comm.is_distributed():
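A sketch of the mixed-precision case the new try/except covers, assuming the float32 operand holds the value 2.0 as in the updated test: torch.allclose raises a RuntimeError for float32/float64 operands, so both are cast to the promoted dtype and compared again.

import heat as ht

a = ht.float32([[2.0, 2.0], [2.0, 2.0]])
f = ht.float64([[2.000005, 2.000005], [2.000005, 2.000005]])
print(ht.allclose(f, a))  # True with the default rtol and atol
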
8 changes: 3 additions & 5 deletions heat/core/manipulations.py
@@ -3372,6 +3372,9 @@ def resplit(arr: DNDarray, axis: int = None) -> DNDarray:
# early out for unchanged content
if axis == arr.split:
return arr.copy()
if not arr.is_distributed():
return factories.array(arr.larray, split=axis, device=arr.device, copy=True)

if axis is None:
# new_arr = arr.copy()
gathered = torch.empty(
@@ -3381,11 +3384,6 @@ def resplit(arr: DNDarray, axis: int = None) -> DNDarray:
arr.comm.Allgatherv(arr.larray, (gathered, counts, displs), recv_axis=arr.split)
new_arr = factories.array(gathered, is_split=axis, device=arr.device, dtype=arr.dtype)
return new_arr
# tensor needs be split/sliced locally
if arr.split is None:
temp = arr.larray[arr.comm.chunk(arr.shape, axis)[2]]
new_arr = factories.array(temp, is_split=axis, device=arr.device, dtype=arr.dtype)
return new_arr

arr_tiles = tiling.SplitTiles(arr)
new_arr = factories.empty(arr.gshape, split=axis, dtype=arr.dtype, device=arr.device)
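A sketch of the new non-distributed shortcut in resplit, assuming a single-process run as in the added test: the data is copied locally and only the split label changes, without any MPI calls.

import heat as ht

data = ht.zeros((10, 10), split=0)
if not data.is_distributed():     # always true on one process
    data2 = ht.resplit(data, 1)   # local copy with the new split label
    print(data2.split, (data == data2).all())
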
1 change: 0 additions & 1 deletion heat/core/tests/test_dndarray.py
@@ -126,7 +126,6 @@ def test_gethalo(self):
# test no data on process
data_np = np.arange(2 * 12).reshape(2, 12)
data = ht.array(data_np, split=0)
print("DEBUGGING: data.lshape_map = ", data.lshape_map)
data.get_halo(1)

data_with_halos = data.array_with_halos
2 changes: 2 additions & 0 deletions heat/core/tests/test_logical.py
@@ -182,13 +182,15 @@ def test_allclose(self):
c = ht.zeros((4, 6), split=0)
d = ht.zeros((4, 6), split=1)
e = ht.zeros((4, 6))
f = ht.float64([[2.000005, 2.000005], [2.000005, 2.000005]])

self.assertFalse(ht.allclose(a, b))
self.assertTrue(ht.allclose(a, b, atol=1e-04))
self.assertTrue(ht.allclose(a, b, rtol=1e-04))
self.assertTrue(ht.allclose(a, 2))
self.assertTrue(ht.allclose(a, 2.0))
self.assertTrue(ht.allclose(2, a))
self.assertTrue(ht.allclose(f, a))
self.assertTrue(ht.allclose(c, d))
self.assertTrue(ht.allclose(c, e))
self.assertTrue(e.allclose(c))
10 changes: 10 additions & 0 deletions heat/core/tests/test_manipulations.py
@@ -2992,6 +2992,16 @@ def test_resplit(self):
self.assertEqual(data2.lshape, (data.comm.size, 1))
self.assertEqual(data2.split, 1)

# resplitting a non-distributed DNDarray with split not None
if ht.MPI_WORLD.size == 1:
data = ht.zeros(10, 10, split=0)
data2 = ht.resplit(data, 1)
data3 = ht.resplit(data, None)
self.assertTrue((data == data2).all())
self.assertTrue((data == data3).all())
self.assertEqual(data2.split, 1)
self.assertTrue(data3.split is None)

# splitting an unsplit tensor should result in slicing the tensor locally
shape = (ht.MPI_WORLD.size, ht.MPI_WORLD.size)
data = ht.zeros(shape)
8 changes: 6 additions & 2 deletions heat/core/tests/test_suites/basic_test.py
@@ -136,8 +136,12 @@ def assert_array_equal(self, heat_array, expected_array):
"Local shapes do not match. "
"Got {} expected {}".format(heat_array.lshape, expected_array[slices].shape),
)
local_heat_numpy = heat_array.numpy()
self.assertTrue(np.allclose(local_heat_numpy, expected_array))
# compare local tensors to corresponding slice of expected_array
is_allclose = np.allclose(heat_array.larray.cpu(), expected_array[slices])
ht_is_allclose = ht.array(
[is_allclose], dtype=ht.bool, is_split=0, device=heat_array.device
)
self.assertTrue(ht.all(ht_is_allclose))

def assert_func_equal(
self,
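A sketch of the rank-local comparison pattern now used in assert_array_equal; the array contents are assumptions and the per-rank slices are taken from comm.chunk. Each rank compares only its local tensor against the matching slice of the expected NumPy array, and the per-rank booleans are reduced with ht.all instead of gathering the whole DNDarray.

import numpy as np
import heat as ht

expected = np.arange(20).reshape(4, 5)
heat_array = ht.array(expected, split=0)
_, _, slices = heat_array.comm.chunk(heat_array.shape, heat_array.split)
ok = np.allclose(heat_array.larray.cpu(), expected[slices])
print(ht.all(ht.array([ok], is_split=0)))  # reduces the per-rank results
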
