Merge pull request #541 from helmholtz-analytics/features/481-halos
Features/481 halos
krajsek authored Apr 24, 2020
2 parents 9da4f99 + 34d19ad commit f71b83f
Showing 3 changed files with 213 additions and 3 deletions.
112 changes: 112 additions & 0 deletions heat/core/dndarray.py
@@ -72,6 +72,9 @@ def __init__(self, array, gshape, dtype, split, device, comm):
self.__split = split
self.__device = device
self.__comm = comm
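        # halo bookkeeping: whether halos are stored, and the tensors received from the neighboring ranks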
self.__ishalo = False
self.__halo_next = None
self.__halo_prev = None

# handle inconsistencies between torch and heat devices
if (
@@ -81,6 +84,14 @@ def __init__(self, array, gshape, dtype, split, device, comm):
):
self.__array = self.__array.to(devices.sanitize_device(self.__device).torch_device)

    @property
    def halo_next(self):
        """Halo fetched from the next rank, or None if no halo is stored."""
        return self.__halo_next

    @property
    def halo_prev(self):
        """Halo fetched from the previous rank, or None if no halo is stored."""
        return self.__halo_prev

@property
def comm(self):
return self.__comm
@@ -248,6 +259,107 @@ def strides(self):
def T(self):
return linalg.transpose(self, axes=None)

    @property
    def array_with_halos(self):
        """
        The local torch.Tensor with the halos fetched beforehand via get_halo
        concatenated along the split axis.
        """
        return self.__cat_halo()

def __prephalo(self, start, end):
"""
Extracts the halo indexed by start, end from self.array in the direction of self.split
Parameters
----------
start : int
start index of the halo extracted from self.array
end : int
end index of the halo extracted from self.array
Returns
-------
halo : torch.Tensor
The halo extracted from self.array
"""
        ix = [slice(None, None, None)] * len(self.shape)
        try:
            ix[self.split] = slice(start, end)
        except IndexError:
            print("Indices out of bounds")

return self.__array[ix].clone().contiguous()

def get_halo(self, halo_size):
"""
Fetch halos of size 'halo_size' from neighboring ranks and save them in self.halo_next/self.halo_prev
in case they are not already stored. If 'halo_size' differs from the size of already stored halos,
the are overwritten.
Parameters
----------
halo_size : int
Size of the halo.
"""
        if not isinstance(halo_size, int):
            raise TypeError(
                "halo_size needs to be of Python type integer, {} given".format(type(halo_size))
            )
        if halo_size < 0:
            raise ValueError(
                "halo_size needs to be a non-negative Python integer, {} given".format(halo_size)
            )

        if self.comm.is_distributed() and self.split is not None:
            min_chunksize = self.shape[self.split] // self.comm.size
            if halo_size > min_chunksize:
                raise ValueError(
                    "halo_size {} needs to be smaller than the minimum chunk size {}".format(
                        halo_size, min_chunksize
                    )
                )

            a_prev = self.__prephalo(0, halo_size)
            a_next = self.__prephalo(-halo_size, None)

            res_prev = None
            res_next = None

            req_list = list()

            # exchange boundary slices with the neighboring ranks; all requests,
            # sends included, are collected and waited on before the halos are stored
            if self.comm.rank != self.comm.size - 1:
                req_list.append(self.comm.Isend(a_next, self.comm.rank + 1))
                res_prev = torch.zeros(a_prev.size(), dtype=a_prev.dtype)
                req_list.append(self.comm.Irecv(res_prev, source=self.comm.rank + 1))

            if self.comm.rank != 0:
                req_list.append(self.comm.Isend(a_prev, self.comm.rank - 1))
                res_next = torch.zeros(a_next.size(), dtype=a_next.dtype)
                req_list.append(self.comm.Irecv(res_next, source=self.comm.rank - 1))

            for req in req_list:
                req.wait()

            # the slice received from rank + 1 is this rank's next halo, and vice versa
            self.__halo_next = res_prev
            self.__halo_prev = res_next
            self.__ishalo = True

    def __cat_halo(self):
        """
        Concatenates the stored halos and the local array along the split axis.

        Returns
        -------
        array with halos : torch.Tensor
            The local torch.Tensor with halo_prev and/or halo_next (if present)
            attached along the split axis.
        """
        tensors = [
            tensor
            for tensor in (self.__halo_prev, self.__array, self.__halo_next)
            if tensor is not None
        ]
        return torch.cat(tensors, self.split)

def abs(self, out=None, dtype=None):
"""
Calculate the absolute value element-wise.
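A minimal usage sketch of the new halo API (added here for illustration; it mirrors the test data below and assumes a launch with two MPI processes, e.g. via mpirun -np 2):

    import numpy as np
    import heat as ht

    # a 2 x 6 array, distributed along the column axis (split=1)
    data = ht.array(np.arange(1, 13).reshape(2, 6), split=1)

    # fetch two columns of overlap from each neighboring rank
    data.get_halo(2)

    # the local torch.Tensor with the received halos concatenated along the split axis
    local = data.array_with_halos
    print(data.comm.rank, local.shape)  # torch.Size([2, 5]) on both ranks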
95 changes: 95 additions & 0 deletions heat/core/tests/test_dndarray.py
@@ -35,6 +35,101 @@ def test_and(self):
ht.equal(int16_tensor & int16_vector, ht.bitwise_and(int16_tensor, int16_vector))
)

def test_gethalo(self):
data_np = np.array([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]])
data = ht.array(data_np, split=1)

if data.comm.size == 2:

halo_next = torch.tensor(np.array([[4, 5], [10, 11]]))
halo_prev = torch.tensor(np.array([[2, 3], [8, 9]]))

data.get_halo(2)

data_with_halos = data.array_with_halos
self.assertEqual(data_with_halos.shape, (2, 5))

if data.comm.rank == 0:
self.assertTrue(torch.equal(data.halo_next, halo_next))
self.assertEqual(data.halo_prev, None)
if data.comm.rank == 1:
self.assertTrue(torch.equal(data.halo_prev, halo_prev))
self.assertEqual(data.halo_next, None)

self.assertEqual(data.array_with_halos.shape, (2, 5))
# exception on wrong argument type in get_halo
with self.assertRaises(TypeError):
data.get_halo("wrong_type")
# exception on wrong argument in get_halo
with self.assertRaises(ValueError):
data.get_halo(-99)
# exception for too large halos
with self.assertRaises(ValueError):
data.get_halo(4)

data_np = np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0, 11.0, 12.0]])
data = ht.array(data_np, split=1)

halo_next = torch.tensor(np.array([[4.0, 5.0], [10.0, 11.0]]))
halo_prev = torch.tensor(np.array([[2.0, 3.0], [8.0, 9.0]]))

data.get_halo(2)

if data.comm.rank == 0:
self.assertTrue(np.isclose(((data.halo_next - halo_next) ** 2).mean().item(), 0.0))
self.assertEqual(data.halo_prev, None)
if data.comm.rank == 1:
self.assertTrue(np.isclose(((data.halo_prev - halo_prev) ** 2).mean().item(), 0.0))
self.assertEqual(data.halo_next, None)

data = ht.ones((10, 2), split=0)

halo_next = torch.tensor(np.array([[1.0, 1.0], [1.0, 1.0]]))
halo_prev = torch.tensor(np.array([[1.0, 1.0], [1.0, 1.0]]))

data.get_halo(2)

if data.comm.rank == 0:
self.assertTrue(np.isclose(((data.halo_next - halo_next) ** 2).mean().item(), 0.0))
self.assertEqual(data.halo_prev, None)
if data.comm.rank == 1:
self.assertTrue(np.isclose(((data.halo_prev - halo_prev) ** 2).mean().item(), 0.0))
self.assertEqual(data.halo_next, None)

if data.comm.size == 3:

halo_1 = torch.tensor(np.array([[2], [8]]))
halo_2 = torch.tensor(np.array([[3], [9]]))
halo_3 = torch.tensor(np.array([[4], [10]]))
halo_4 = torch.tensor(np.array([[5], [11]]))

data.get_halo(1)

data_with_halos = data.array_with_halos

if data.comm.rank == 0:
self.assertTrue(torch.equal(data.halo_next, halo_2))
self.assertEqual(data.halo_prev, None)
self.assertEqual(data_with_halos.shape, (2, 3))
if data.comm.rank == 1:
self.assertTrue(torch.equal(data.halo_prev, halo_1))
self.assertTrue(torch.equal(data.halo_next, halo_4))
self.assertEqual(data_with_halos.shape, (2, 4))
if data.comm.rank == 2:
self.assertEqual(data.halo_next, None)
self.assertTrue(torch.equal(data.halo_prev, halo_3))
self.assertEqual(data_with_halos.shape, (2, 3))

# exception on wrong argument type in get_halo
with self.assertRaises(TypeError):
data.get_halo("wrong_type")
# exception on wrong argument in get_halo
with self.assertRaises(ValueError):
data.get_halo(-99)
# exception for too large halos
with self.assertRaises(ValueError):
data.get_halo(4)

def test_astype(self):
data = ht.float32([[1, 2, 3], [4, 5, 6]], device=ht_device)

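A serial sketch (plain torch, no MPI; my illustration) of what the three-process branch of test_gethalo expects: each halo is just a boundary slice of the neighboring chunk, concatenated along the split axis:

    import torch

    full = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]])
    chunks = torch.chunk(full, 3, dim=1)  # the local chunks of split=1 over 3 ranks

    for rank, local in enumerate(chunks):
        halo_prev = chunks[rank - 1][:, -1:] if rank > 0 else None
        halo_next = chunks[rank + 1][:, :1] if rank < 2 else None
        parts = [t for t in (halo_prev, local, halo_next) if t is not None]
        print(rank, torch.cat(parts, dim=1).shape)
        # rank 0: (2, 3), rank 1: (2, 4), rank 2: (2, 3), matching the assertions above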
9 changes: 6 additions & 3 deletions heat/core/tiling.py
@@ -376,9 +376,12 @@ def __init__(self, arr, tiles_per_proc=2):
# then the local data needs to be redistributed to fit the full diagonal on as many
# processes as possible
# if any(lshape_map[..., arr.split] == 1):
-        last_diag_pr, col_per_proc_list, col_inds, tile_columns = self.__adjust_lshape_sp0_1tile(
-            arr, col_inds, lshape_map, tiles_per_proc
-        )
+        (
+            last_diag_pr,
+            col_per_proc_list,
+            col_inds,
+            tile_columns,
+        ) = self.__adjust_lshape_sp0_1tile(arr, col_inds, lshape_map, tiles_per_proc)
# re-test for empty processes and remove empty rows
empties = torch.where(lshape_map[..., 0] == 0)[0]
if empties.numel() > 0:
Expand Down
