Fix edge-case contiguity mismatch for Allgatherv #1058

Merged on Jan 19, 2023 (27 commits)
Changes from 22 commits

Commits
f46ae67
Fix edge-case contiguity mismatch for Allgatherv
ClaudiaComito Dec 12, 2022
4da69fd
merge branch release/1.2.x
ClaudiaComito Dec 12, 2022
27ea911
Update ubuntu
ClaudiaComito Dec 12, 2022
d0fb6c8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 12, 2022
0e704d4
switch back to ubuntu 20.04
ClaudiaComito Dec 12, 2022
f5d7850
pull
ClaudiaComito Dec 12, 2022
acfe9bd
Upgrade CI to ubuntu 22.04 and cuda 11.7.1
ClaudiaComito Dec 12, 2022
0fd3d87
avoid unnecessary gathering of test DNDarrays
ClaudiaComito Dec 20, 2022
3c4c07c
early out for resplit of non-distributed DNDarrays
ClaudiaComito Dec 20, 2022
989e0f4
match split of comparison array to expected output
ClaudiaComito Dec 20, 2022
6d66fad
avoid MPI calls in non-distributed cases
ClaudiaComito Dec 20, 2022
a37b4d3
avoid MPI calls in non-distributed resplit
ClaudiaComito Dec 20, 2022
8eebe10
set default to None
ClaudiaComito Dec 20, 2022
22c5c68
remove print statement
ClaudiaComito Dec 20, 2022
c692bff
upgrade torch version
ClaudiaComito Dec 20, 2022
df6a4e5
copy to cpu before comparing
ClaudiaComito Dec 20, 2022
af0e721
use ht.allclose instead of np.allclose
ClaudiaComito Dec 23, 2022
bac6d4e
cast different dtype operands to promoted dtype within torch call
ClaudiaComito Dec 23, 2022
c0c6362
compare local tensors to corresponding slice of expected_array only
ClaudiaComito Dec 23, 2022
587bc05
expand tests
ClaudiaComito Dec 23, 2022
24239a1
remove redundant code
ClaudiaComito Dec 23, 2022
38c00a3
use pytorch with cuda117 support
mtar Jan 3, 2023
5d25588
[skip ci] Update heat/core/communication.py
ClaudiaComito Jan 10, 2023
b382d4b
[skip ci] Update heat/core/communication.py
ClaudiaComito Jan 10, 2023
26d92bf
[skip ci] Update heat/core/communication.py
ClaudiaComito Jan 10, 2023
79e13e2
Remove dead code
ClaudiaComito Jan 15, 2023
d62a64d
Update pytorch-latest.txt
mtar Jan 17, 2023
2 changes: 1 addition & 1 deletion .github/release-drafter.yml
@@ -34,7 +34,7 @@ categories:
label: 'chore'
- title: '🧪 Testing'
label: 'testing'

change-template: '- #$NUMBER $TITLE (by @$AUTHOR)'
categorie-template: '### $TITLE'
exclude-labels:
4 changes: 2 additions & 2 deletions .gitlab-ci.yml
@@ -1,5 +1,5 @@
test:
image: nvidia/cuda:11.6.2-runtime-ubuntu20.04
image: nvidia/cuda:11.7.1-runtime-ubuntu22.04
tags:
- cuda
- x86_64
@@ -9,7 +9,7 @@ test:
- DEBIAN_FRONTEND=noninteractive apt -y install libopenmpi-dev openmpi-bin openmpi-doc
- apt -y install libhdf5-openmpi-dev libpnetcdf-dev
- pip install pytest coverage
- pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
- pip3 install torch torchvision torchaudio
- pip install .[hdf5,netcdf]
- COVERAGE_FILE=report/cov/coverage1 HEAT_TEST_USE_DEVICE=cpu mpirun --allow-run-as-root -n 1 coverage run --source=heat --parallel-mode -m pytest --junitxml=report/test/report1.xml heat/
- COVERAGE_FILE=report/cov/coverage2 HEAT_TEST_USE_DEVICE=gpu mpirun --allow-run-as-root -n 3 coverage run --source=heat --parallel-mode -m pytest --junitxml=report/test/report3.xml heat/
42 changes: 28 additions & 14 deletions heat/core/communication.py
@@ -240,7 +240,11 @@ def counts_displs_shape(

@classmethod
def mpi_type_and_elements_of(
cls, obj: Union[DNDarray, torch.Tensor], counts: Tuple[int], displs: Tuple[int]
cls,
obj: Union[DNDarray, torch.Tensor],
counts: Tuple[int],
displs: Tuple[int],
is_contiguous: bool,
) -> Tuple[MPI.Datatype, Tuple[int, ...]]:
"""
Determines the MPI data type and number of respective elements for the given tensor (:class:`~heat.core.dndarray.DNDarray`
@@ -255,12 +259,18 @@ def mpi_type_and_elements_of(
Optional counts arguments for variable MPI-calls (e.g. Alltoallv)
displs : Tuple[ints,...], optional
Optional displacements arguments for variable MPI-calls (e.g. Alltoallv)
is_contiguous: bool, optional
Optional information on global contiguity of the memory-distributed object. If `None`, it will be set to local contiguity via ``torch.Tensor.is_contiguous()``.
# ToDo: The option to explicitly specify the counts and displacements to be sent still needs proper implementation
"""
mpi_type, elements = cls.__mpi_type_mappings[obj.dtype], torch.numel(obj)

# simple case, continuous memory can be transmitted as is
if obj.is_contiguous():
# simple case, contiguous memory can be transmitted as is
if is_contiguous is None:
# determine local contiguity
is_contiguous = obj.is_contiguous()
Review thread on these lines:

mtar (Collaborator):

What happens if the value is different on the processes? How likely is it?

ClaudiaComito (Contributor, Author):

Thanks @mtar, that's a great question. The obvious case in which this might happen is the permutation, and this is dealt with in this PR. Outside of that, we're simply falling back to the previous implementation.

I could add a global check that sets is_contiguous to False if the local contiguities are dishomogeneous.

ClaudiaComito (Contributor, Author), replying to mtar:

> Thank you @ClaudiaComito. You added extra lines for arrays with different splits when the number of MPI processes is one. It doesn't make much sense to pass an argument when the split won't take place on one process. What do you think about disallowing splits or automatically setting the value to None when only a single process is involved at array creation time? It would save us some tests/checks.

This is a general discussion worth having, maybe not re: this bug fix. My main argument against setting all splits to None when running on 1 MPI process is that it will be confusing for users while they are testing their code (potentially on 1 process or even interactively).

Anyway, let's discuss it in a separate Issue. As far as I'm concerned, I'm done with this PR.

ClaudiaComito (Contributor, Author):

> I could add a global check that sets is_contiguous to False if the local contiguities are dishomogeneous.

I've decided not to add (yet another) global check for contiguous status for now, as I can't think of the appropriate edge case to test it. We are already testing for column-first memory layout operations. If anybody can think of something, let me know.


if is_contiguous:
if counts is None:
return mpi_type, elements
else:
Expand All @@ -273,7 +283,7 @@ def mpi_type_and_elements_of(
),
)

# non-continuous memory, e.g. after a transpose, has to be packed in derived MPI types
# non-contiguous memory, e.g. after a transpose, has to be packed in derived MPI types
elements = obj.shape[0]
shape = obj.shape[1:]
strides = [1] * len(shape)
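For reference, a minimal sketch (not part of this PR) of the global contiguity check discussed in the review thread above. It assumes mpi4py's pickle-based `allreduce` and a hypothetical helper name:

```python
from mpi4py import MPI
import torch


def globally_contiguous(local_tensor: torch.Tensor, comm: MPI.Comm = MPI.COMM_WORLD) -> bool:
    # torch.Tensor.is_contiguous() only reflects the local memory layout;
    # a logical AND across all ranks would catch the "dishomogeneous" case.
    return bool(comm.allreduce(local_tensor.is_contiguous(), op=MPI.LAND))
```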
@@ -305,7 +315,11 @@ def as_mpi_memory(cls, obj) -> MPI.memory:

@classmethod
def as_buffer(
cls, obj: torch.Tensor, counts: Tuple[int] = None, displs: Tuple[int] = None
cls,
obj: torch.Tensor,
counts: Tuple[int] = None,
displs: Tuple[int] = None,
is_contiguous: bool = None,
) -> List[Union[MPI.memory, Tuple[int, int], MPI.Datatype]]:
"""
Converts a passed ``torch.Tensor`` into a memory buffer object with associated number of elements and MPI data type.
@@ -318,14 +332,16 @@ def as_buffer(
Optional counts arguments for variable MPI-calls (e.g. Alltoallv)
displs : Tuple[int,...], optional
Optional displacements arguments for variable MPI-calls (e.g. Alltoallv)
is_contiguous: bool, optional
Optional information on global contiguity of the memory-distributed object.
"""
squ = False
if not obj.is_contiguous() and obj.ndim == 1:
# this makes the math work below this function.
obj.unsqueeze_(-1)
squ = True
mpi_type, elements = cls.mpi_type_and_elements_of(obj, counts, displs)

mpi_type, elements = cls.mpi_type_and_elements_of(obj, counts, displs, is_contiguous)
mpi_mem = cls.as_mpi_memory(obj)
if squ:
# the squeeze happens in the mpi_type_and_elements_of function in the case of a
@@ -1037,7 +1053,6 @@ def __allgather_like(
type(sendbuf)
)
)

# unpack the receive buffer
if isinstance(recvbuf, tuple):
recvbuf, recv_counts, recv_displs = recvbuf
@@ -1053,17 +1068,18 @@

# keep a reference to the original buffer object
original_recvbuf = recvbuf

sbuf_is_contiguous, rbuf_is_contiguous = None, None
# permute the send_axis order so that the split send_axis is the first to be transmitted
if axis != 0:
send_axis_permutation = list(range(sendbuf.ndimension()))
send_axis_permutation[0], send_axis_permutation[axis] = axis, 0
sendbuf = sendbuf.permute(*send_axis_permutation)
sbuf_is_contiguous = False

if axis != 0:
recv_axis_permutation = list(range(recvbuf.ndimension()))
recv_axis_permutation[0], recv_axis_permutation[axis] = axis, 0
recvbuf = recvbuf.permute(*recv_axis_permutation)
rbuf_is_contiguous = False
else:
recv_axis_permutation = None

@@ -1074,20 +1090,18 @@
if sendbuf is MPI.IN_PLACE or not isinstance(sendbuf, torch.Tensor):
mpi_sendbuf = sbuf
else:
mpi_sendbuf = self.as_buffer(sbuf, send_counts, send_displs)
mpi_sendbuf = self.as_buffer(sbuf, send_counts, send_displs, sbuf_is_contiguous)
if send_counts is not None:
mpi_sendbuf[1] = mpi_sendbuf[1][0][self.rank]

if recvbuf is MPI.IN_PLACE or not isinstance(recvbuf, torch.Tensor):
mpi_recvbuf = rbuf
else:
mpi_recvbuf = self.as_buffer(rbuf, recv_counts, recv_displs)
mpi_recvbuf = self.as_buffer(rbuf, recv_counts, recv_displs, rbuf_is_contiguous)
if recv_counts is None:
mpi_recvbuf[1] //= self.size

# perform the scatter operation
exit_code = func(mpi_sendbuf, mpi_recvbuf, **kwargs)

return exit_code, sbuf, rbuf, original_recvbuf, recv_axis_permutation

def Allgather(
@@ -1260,7 +1274,7 @@ def __alltoall_like(
# keep a reference to the original buffer object
original_recvbuf = recvbuf

# Simple case, continuous buffers can be transmitted as is
# Simple case, contiguous buffers can be transmitted as is
if send_axis < 2 and recv_axis < 2:
send_axis_permutation = list(range(recvbuf.ndimension()))
recv_axis_permutation = list(range(recvbuf.ndimension()))
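As a plain PyTorch aside (not Heat-specific code), this is what the axis permutations above do to memory layout, and why the permuted send/receive buffers get flagged as non-contiguous:

```python
import torch

x = torch.arange(12).reshape(3, 4)
print(x.is_contiguous())   # True: row-major layout

y = x.permute(1, 0)        # same storage, strided view
print(y.is_contiguous())   # False: would need a derived MPI datatype

z = y.contiguous()         # explicit copy restores a contiguous layout
print(z.is_contiguous())   # True
```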
3 changes: 3 additions & 0 deletions heat/core/dndarray.py
@@ -1268,8 +1268,11 @@ def resplit_(self, axis: int = None):
axis = sanitize_axis(self.shape, axis)

# early out for unchanged content
if self.comm.size == 1:
self.__split = axis
if axis == self.split:
return self

if axis is None:
gathered = torch.empty(
self.shape, dtype=self.dtype.torch_type(), device=self.device.torch_device
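An illustrative run of the new early out, assuming a single MPI process; resplit_ just relabels the split attribute without any communication:

```python
import heat as ht

# run with: mpirun -n 1 python example.py
a = ht.zeros((10, 10), split=0)
a.resplit_(1)        # no data movement on one process
print(a.split)       # 1
a.resplit_(None)
print(a.split)       # None
```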
27 changes: 17 additions & 10 deletions heat/core/linalg/basics.py
@@ -510,24 +510,31 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
if b.dtype != c_type:
b = c_type(b, device=b.device)

# early out for single-process setup, torch matmul
if a.comm.size == 1:
ret = factories.array(torch.matmul(a.larray, b.larray), device=a.device)
if gpu_int_flag:
ret = og_type(ret, device=a.device)
return ret

if a.split is None and b.split is None: # matmul from torch
if len(a.gshape) < 2 or len(b.gshape) < 2 or not allow_resplit:
# if either of A or B is a vector
ret = factories.array(torch.matmul(a.larray, b.larray), device=a.device)
if gpu_int_flag:
ret = og_type(ret, device=a.device)
return ret
else:
a.resplit_(0)
slice_0 = a.comm.chunk(a.shape, a.split)[2][0]
hold = a.larray @ b.larray

c = factories.zeros((a.gshape[-2], b.gshape[1]), dtype=c_type, device=a.device)
c.larray[slice_0.start : slice_0.stop, :] += hold
c.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
if gpu_int_flag:
c = og_type(c, device=a.device)
return c
a.resplit_(0)
slice_0 = a.comm.chunk(a.shape, a.split)[2][0]
hold = a.larray @ b.larray

c = factories.zeros((a.gshape[-2], b.gshape[1]), dtype=c_type, device=a.device)
c.larray[slice_0.start : slice_0.stop, :] += hold
c.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
if gpu_int_flag:
c = og_type(c, device=a.device)
return c

# if they are vectors they need to be expanded to be the proper dimensions
vector_flag = False # flag to run squeeze at the end of the function
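A small illustration of the new single-process early out, assuming one MPI process; the product is computed by a local torch.matmul even though `a` carries a split attribute:

```python
import heat as ht
import torch

# run with: mpirun -n 1 python example.py
a = ht.ones((4, 3), split=0)
b = ht.ones((3, 5))
c = ht.matmul(a, b)

print(c.shape)                                            # (4, 5)
print(torch.allclose(c.larray, torch.full((4, 5), 3.0)))  # True
```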
26 changes: 24 additions & 2 deletions heat/core/linalg/tests/test_basics.py
@@ -238,13 +238,15 @@ def test_inv(self):
self.assertTrue(ht.allclose(ainv, ares, atol=1e-6))

# distributed
# ares = ht.array([[2.0, 2, 1], [3, 4, 1], [0, 1, -1]], split=0)
a = ht.array([[5.0, -3, 2], [-3, 2, -1], [-3, 2, -2]], split=0)
ainv = ht.linalg.inv(a)
self.assertEqual(ainv.split, a.split)
self.assertEqual(ainv.device, a.device)
self.assertTupleEqual(ainv.shape, a.shape)
self.assertTrue(ht.allclose(ainv, ares, atol=1e-6))

ares = ht.array([[2.0, 2, 1], [3, 4, 1], [0, 1, -1]], split=1)
a = ht.array([[5.0, -3, 2], [-3, 2, -1], [-3, 2, -2]], split=1)
ainv = ht.linalg.inv(a)
self.assertEqual(ainv.split, a.split)
@@ -281,14 +283,15 @@ def test_inv(self):
self.assertTrue(ht.allclose(ainv, ares, atol=1e-6))

# pivoting row change
ares = ht.array([[-1, 0, 2], [2, 0, -1], [-6, 3, 0]], dtype=ht.double) / 3.0
ares = ht.array([[-1, 0, 2], [2, 0, -1], [-6, 3, 0]], dtype=ht.double, split=0) / 3.0
a = ht.array([[1, 2, 0], [2, 4, 1], [2, 1, 0]], dtype=ht.double, split=0)
ainv = ht.linalg.inv(a)
self.assertEqual(ainv.split, a.split)
self.assertEqual(ainv.device, a.device)
self.assertTupleEqual(ainv.shape, a.shape)
self.assertTrue(ht.allclose(ainv, ares, atol=1e-6))

ares = ht.array([[-1, 0, 2], [2, 0, -1], [-6, 3, 0]], dtype=ht.double, split=1) / 3.0
a = ht.array([[1, 2, 0], [2, 4, 1], [2, 1, 0]], dtype=ht.double, split=1)
ainv = ht.linalg.inv(a)
self.assertEqual(ainv.split, a.split)
@@ -365,9 +368,28 @@ def test_matmul(self):
self.assertEqual(ret00.shape, (n, k))
self.assertEqual(ret00.dtype, ht.float)
self.assertEqual(ret00.split, None)
self.assertEqual(a.split, 0)
if a.comm.size > 1:
self.assertEqual(a.split, 0)
self.assertEqual(b.split, None)

# splits 0 None on 1 process
if a.comm.size == 1:
a = ht.ones((n, m), split=0)
b = ht.ones((j, k), split=None)
a[0] = ht.arange(1, m + 1)
a[:, -1] = ht.arange(1, n + 1)
b[0] = ht.arange(1, k + 1)
b[:, 0] = ht.arange(1, j + 1)
ret00 = ht.matmul(a, b, allow_resplit=True)

self.assertEqual(ht.all(ret00 == ht.array(a_torch @ b_torch)), 1)
self.assertIsInstance(ret00, ht.DNDarray)
self.assertEqual(ret00.shape, (n, k))
self.assertEqual(ret00.dtype, ht.float)
self.assertEqual(ret00.split, None)
self.assertEqual(a.split, 0)
self.assertEqual(b.split, None)

if a.comm.size > 1:
# splits 00
a = ht.ones((n, m), split=0, dtype=ht.float64)
14 changes: 13 additions & 1 deletion heat/core/logical.py
@@ -140,7 +140,19 @@ def allclose(
t1, t2 = __sanitize_close_input(x, y)

# no sanitation for shapes of x and y needed, torch.allclose raises relevant errors
_local_allclose = torch.tensor(torch.allclose(t1.larray, t2.larray, rtol, atol, equal_nan))
try:
_local_allclose = torch.tensor(torch.allclose(t1.larray, t2.larray, rtol, atol, equal_nan))
except RuntimeError:
promoted_dtype = torch.promote_types(t1.larray.dtype, t2.larray.dtype)
_local_allclose = torch.tensor(
torch.allclose(
t1.larray.type(promoted_dtype),
t2.larray.type(promoted_dtype),
rtol,
atol,
equal_nan,
)
)

# If x is distributed, then y is also distributed along the same axis
if t1.comm.is_distributed():
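For context, a minimal standalone reproduction of the dtype issue handled above; on the torch versions where mixed-precision `allclose` raises a `RuntimeError`, casting both operands to the promoted dtype resolves it:

```python
import torch

x = torch.ones(3, dtype=torch.float32)
y = torch.ones(3, dtype=torch.float64)

common = torch.promote_types(x.dtype, y.dtype)   # torch.float64
print(torch.allclose(x.type(common), y.type(common), rtol=1e-05, atol=1e-08))  # True
```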
8 changes: 3 additions & 5 deletions heat/core/manipulations.py
@@ -3372,6 +3372,9 @@ def resplit(arr: DNDarray, axis: int = None) -> DNDarray:
# early out for unchanged content
if axis == arr.split:
return arr.copy()
if not arr.is_distributed():
return factories.array(arr.larray, split=axis, device=arr.device, copy=True)

if axis is None:
# new_arr = arr.copy()
gathered = torch.empty(
@@ -3381,11 +3384,6 @@
arr.comm.Allgatherv(arr.larray, (gathered, counts, displs), recv_axis=arr.split)
new_arr = factories.array(gathered, is_split=axis, device=arr.device, dtype=arr.dtype)
return new_arr
# tensor needs be split/sliced locally
if arr.split is None:
temp = arr.larray[arr.comm.chunk(arr.shape, axis)[2]]
new_arr = factories.array(temp, is_split=axis, device=arr.device, dtype=arr.dtype)
return new_arr

arr_tiles = tiling.SplitTiles(arr)
new_arr = factories.empty(arr.gshape, split=axis, dtype=arr.dtype, device=arr.device)
1 change: 0 additions & 1 deletion heat/core/tests/test_dndarray.py
@@ -126,7 +126,6 @@ def test_gethalo(self):
# test no data on process
data_np = np.arange(2 * 12).reshape(2, 12)
data = ht.array(data_np, split=0)
print("DEBUGGING: data.lshape_map = ", data.lshape_map)
data.get_halo(1)

data_with_halos = data.array_with_halos
2 changes: 2 additions & 0 deletions heat/core/tests/test_logical.py
@@ -182,13 +182,15 @@ def test_allclose(self):
c = ht.zeros((4, 6), split=0)
d = ht.zeros((4, 6), split=1)
e = ht.zeros((4, 6))
f = ht.float64([[2.000005, 2.000005], [2.000005, 2.000005]])

self.assertFalse(ht.allclose(a, b))
self.assertTrue(ht.allclose(a, b, atol=1e-04))
self.assertTrue(ht.allclose(a, b, rtol=1e-04))
self.assertTrue(ht.allclose(a, 2))
self.assertTrue(ht.allclose(a, 2.0))
self.assertTrue(ht.allclose(2, a))
self.assertTrue(ht.allclose(f, a))
self.assertTrue(ht.allclose(c, d))
self.assertTrue(ht.allclose(c, e))
self.assertTrue(e.allclose(c))
10 changes: 10 additions & 0 deletions heat/core/tests/test_manipulations.py
@@ -2992,6 +2992,16 @@ def test_resplit(self):
self.assertEqual(data2.lshape, (data.comm.size, 1))
self.assertEqual(data2.split, 1)

# resplitting a non-distributed DNDarray with split not None
if ht.MPI_WORLD.size == 1:
data = ht.zeros(10, 10, split=0)
data2 = ht.resplit(data, 1)
data3 = ht.resplit(data, None)
self.assertTrue((data == data2).all())
self.assertTrue((data == data3).all())
self.assertEqual(data2.split, 1)
self.assertTrue(data3.split is None)

# splitting an unsplit tensor should result in slicing the tensor locally
shape = (ht.MPI_WORLD.size, ht.MPI_WORLD.size)
data = ht.zeros(shape)
8 changes: 6 additions & 2 deletions heat/core/tests/test_suites/basic_test.py
@@ -136,8 +136,12 @@ def assert_array_equal(self, heat_array, expected_array):
"Local shapes do not match. "
"Got {} expected {}".format(heat_array.lshape, expected_array[slices].shape),
)
local_heat_numpy = heat_array.numpy()
self.assertTrue(np.allclose(local_heat_numpy, expected_array))
# compare local tensors to corresponding slice of expected_array
is_allclose = np.allclose(heat_array.larray.cpu(), expected_array[slices])
ht_is_allclose = ht.array(
[is_allclose], dtype=ht.bool, is_split=0, device=heat_array.device
)
self.assertTrue(ht.all(ht_is_allclose))

def assert_func_equal(
self,
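A sketch of the per-rank comparison pattern used in `assert_array_equal` above, with simplified names; each rank checks only its local slice and the verdicts are combined with `ht.all` instead of gathering the full array to every process:

```python
import heat as ht
import numpy as np

expected = np.arange(20, dtype=np.float32).reshape(4, 5)
heat_array = ht.array(expected, split=0)

# comm.chunk returns (offset, local_shape, slices) for this rank's portion
_, _, slices = heat_array.comm.chunk(heat_array.gshape, heat_array.split)
local_ok = np.allclose(heat_array.larray.cpu().numpy(), expected[slices])

verdict = ht.array([local_ok], dtype=ht.bool, is_split=0)
print(ht.all(verdict))
```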
2 changes: 1 addition & 1 deletion setup.py
@@ -33,7 +33,7 @@
install_requires=[
"mpi4py>=3.0.0",
"numpy>=1.13.0",
"torch>=1.7.0, <1.13.1",
"torch>=1.7.0, <1.13.2",
"scipy>=0.14.0",
"pillow>=6.0.0",
"torchvision>=0.8.0",