Skip to content

Commit

Permalink
Implement distributed unfold operation (#1419)
Browse files Browse the repository at this point in the history
* implemented the easy cases and a simple test

* general case

* exception handling, added test with two unfold (2D slices)

* added unfold to manipulations module

* added test

* fixed behavior for empty unfold_loc, exception handling for size - 1 > chunk_size
more tests

* wrong exception type in test

* fixed wrong exception type in tests

* fixed test for single node setting

* added better docstring

* added test to cover case that there are no fully local unfolds for a node

* fixed test case of no fully local unfolds

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixed error due to unspecified torch device

* added tests with different datatypes

* renamed ´dimension´ to ´axis´

* use `DNDarray.counts_displs()´

* updated docstring

* use sanitize_axis

* support one-sided halo

* use `DNDarray.array_with_halos`

* fixed condition for empty local unfold data

* more tests

* detach after cloning

* test: blocking send in get_halo()

* replaced Send by Isend in "next"

* int64 in batchparallel clustering predict

* added error for size=1

* Update batchparallelclustering.py

Undid my stupid change before that belongs to another issue

* Removed old/dead code, resolved review

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com>
Co-authored-by: Fabian Hoppe <112093564+mrfh92@users.noreply.github.com>
Co-authored-by: Hoppe <mrhf92@gmail.com>
  • Loading branch information
5 people authored Aug 19, 2024
1 parent ef97474 commit 2ecf597
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 15 deletions.
38 changes: 23 additions & 15 deletions heat/core/dndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,14 +384,18 @@ def __prephalo(self, start, end) -> torch.Tensor:

return self.__array[ix].clone().contiguous()

def get_halo(self, halo_size: int) -> torch.Tensor:
def get_halo(self, halo_size: int, prev: bool = True, next: bool = True) -> torch.Tensor:
"""
Fetch halos of size ``halo_size`` from neighboring ranks and save them in ``self.halo_next/self.halo_prev``.
Parameters
----------
halo_size : int
Size of the halo.
prev : bool, optional
If True, fetch the halo from the previous rank. Default: True.
next : bool, optional
If True, fetch the halo from the next rank. Default: True.
"""
if not isinstance(halo_size, int):
raise TypeError(
Expand Down Expand Up @@ -433,25 +437,29 @@ def get_halo(self, halo_size: int) -> torch.Tensor:
req_list = []

# exchange data with next populated process
if rank != last_rank:
self.comm.Isend(a_next, next_rank)
res_prev = torch.zeros(
a_prev.size(), dtype=a_prev.dtype, device=self.device.torch_device
)
req_list.append(self.comm.Irecv(res_prev, source=next_rank))
if prev:
if rank != last_rank:
self.comm.Isend(a_next, next_rank)
if rank != first_rank:
res_prev = torch.zeros(
a_prev.size(), dtype=a_prev.dtype, device=self.device.torch_device
)
req_list.append(self.comm.Irecv(res_prev, source=prev_rank))

if rank != first_rank:
self.comm.Isend(a_prev, prev_rank)
res_next = torch.zeros(
a_next.size(), dtype=a_next.dtype, device=self.device.torch_device
)
req_list.append(self.comm.Irecv(res_next, source=prev_rank))
if next:
if rank != first_rank:
req_list.append(self.comm.Isend(a_prev, prev_rank))
if rank != last_rank:
res_next = torch.zeros(
a_next.size(), dtype=a_next.dtype, device=self.device.torch_device
)
req_list.append(self.comm.Irecv(res_next, source=next_rank))

for req in req_list:
req.Wait()

self.__halo_next = res_prev
self.__halo_prev = res_next
self.__halo_next = res_next
self.__halo_prev = res_prev
self.__ishalo = True

def __cat_halo(self) -> torch.Tensor:
Expand Down
90 changes: 90 additions & 0 deletions heat/core/manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
"unique",
"vsplit",
"vstack",
"unfold",
]


Expand Down Expand Up @@ -4213,3 +4214,92 @@ def mpi_topk(a, b, mpi_type):


MPI_TOPK = MPI.Op.Create(mpi_topk, commute=True)


def unfold(a: DNDarray, axis: int, size: int, step: int = 1):
"""
Returns a DNDarray which contains all slices of size `size` in the axis `axis`.
Behaves like torch.Tensor.unfold for DNDarrays. [torch.Tensor.unfold](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html)
Parameters
----------
a : DNDarray
array to unfold
axis : int
axis in which unfolding happens
size : int
the size of each slice that is unfolded, must be greater than 1
step : int
the step between each slice, must be at least 1
Example:
```
>>> x = ht.arange(1., 8)
>>> x
DNDarray([1., 2., 3., 4., 5., 6., 7.], dtype=ht.float32, device=cpu:0, split=e)
>>> ht.unfold(x, 0, 2, 1)
DNDarray([[1., 2.],
[2., 3.],
[3., 4.],
[4., 5.],
[5., 6.],
[6., 7.]], dtype=ht.float32, device=cpu:0, split=None)
>>> ht.unfold(x, 0, 2, 2)
DNDarray([[1., 2.],
[3., 4.],
[5., 6.]], dtype=ht.float32, device=cpu:0, split=None)
```
Note
---------
You have to make sure that every node has at least chunk size size-1 if the split axis of the array is the unfold axis.
"""
if step < 1:
raise ValueError("step must be >= 1.")
if size <= 1:
raise ValueError("size must be > 1.")
axis = stride_tricks.sanitize_axis(a.shape, axis)
if size > a.shape[axis]:
raise ValueError(
f"maximum size for DNDarray at axis {axis} is {a.shape[axis]} but size is {size}."
)

comm = a.comm
dev = a.device
tdev = dev.torch_device

if a.split is None or comm.size == 1 or a.split != axis: # early out
ret = factories.array(
a.larray.unfold(axis, size, step), is_split=a.split, device=dev, comm=comm
)

return ret
else: # comm.size > 1 and split axis == unfold axis
# index range [0:sizedim-1-(size-1)] = [0:sizedim-size]
# --> size of axis: ceil((sizedim-size+1) / step) = floor(sizedim-size) / step)) + 1
# ret_shape = (*a_shape[:axis], int((a_shape[axis]-size)/step) + 1, a_shape[axis+1:], size)

if (size - 1 > a.lshape_map[:, axis]).any():
raise RuntimeError("Chunk-size needs to be at least size - 1.")
a.get_halo(size - 1, prev=False)

counts, displs = a.counts_displs()
displs = torch.tensor(displs, device=tdev)

# min local index in unfold axis
min_index = ((displs[comm.rank] - 1) // step + 1) * step - displs[comm.rank]
if min_index >= a.lshape[axis] or (
comm.rank == comm.size - 1 and min_index + size > a.lshape[axis]
):
loc_unfold_shape = list(a.lshape)
loc_unfold_shape[axis] = 0
ret_larray = torch.zeros((*loc_unfold_shape, size), device=tdev)
else: # unfold has local data
ret_larray = a.array_with_halos[
axis * (slice(None, None, None),) + (slice(min_index, None, None), Ellipsis)
].unfold(axis, size, step)

ret = factories.array(ret_larray, is_split=axis, device=dev, comm=comm)

return ret
60 changes: 60 additions & 0 deletions heat/core/tests/test_manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3752,3 +3752,63 @@ def test_vstack(self):
b = ht.ones((12,), split=0)
res = ht.vstack((a, b))
self.assertEqual(res.shape, (2, 12))

def test_unfold(self):
dtypes = (ht.int, ht.float)

for dtype in dtypes: # test with different datatypes
# exceptions
n = 1000
x = ht.arange(n, dtype=dtype)
with self.assertRaises(ValueError): # size too small
ht.unfold(x, 0, 1, 1)
with self.assertRaises(ValueError): # step too small
ht.unfold(x, 0, 2, 0)
x.resplit_(0)
min_chunk_size = x.lshape_map[:, 0].min().item()
if min_chunk_size + 2 > n: # size too large
with self.assertRaises(ValueError):
ht.unfold(x, 0, min_chunk_size + 2)
else: # size too large for chunk_size
with self.assertRaises(RuntimeError):
ht.unfold(x, 0, min_chunk_size + 2)
with self.assertRaises(ValueError): # size too large
ht.unfold(x, 0, n + 1, 1)
ht.unfold(
x, 0, min_chunk_size, min_chunk_size + 1
) # no fully local unfolds on some nodes

# 2D sliding views
n = 100

x = torch.arange(n * n).reshape((n, n))
y = ht.array(x, dtype)
y.resplit_(0)

u = x.unfold(0, 3, 3)
u = u.unfold(1, 3, 3)
u = ht.array(u)
v = ht.unfold(y, 0, 3, 3)
v = ht.unfold(v, 1, 3, 3)

self.assertTrue(ht.equal(u, v))

# more dimensions, different split axes
n = 53
k = 3 # number of dimensions
shape = k * (n,)
size = n**k

x = torch.arange(size).reshape(shape)
_y = x.clone().detach()
y = ht.array(_y, dtype)

for split in (None, *range(k)):
y.resplit_(split)
for size in range(2, 9):
for step in range(1, 21):
for dimension in range(k):
u = ht.array(x.unfold(dimension, size, step))
v = ht.unfold(y, dimension, size, step)

self.assertTrue(ht.equal(u, v))

0 comments on commit 2ecf597

Please sign in to comment.