diff --git a/CHANGELOG.md b/CHANGELOG.md
index d54b06c055..51bd0a9c2e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,9 @@
+# Pending additions
+
+## Bug Fixes
+- [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension
+
+
 # v1.1.0
 
 ## Highlights
diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py
index 2243c4c8e3..d61b882086 100644
--- a/heat/core/dndarray.py
+++ b/heat/core/dndarray.py
@@ -9,7 +9,7 @@
 from inspect import stack
 from mpi4py import MPI
 from pathlib import Path
-from typing import List, Union, Tuple, TypeVar
+from typing import List, Union, Tuple, TypeVar, Optional
 
 warnings.simplefilter("always", ResourceWarning)
 
@@ -80,6 +80,7 @@ def __init__(
         self.__ishalo = False
         self.__halo_next = None
         self.__halo_prev = None
+        self.__lshape_map = None
 
         # check for inconsistencies between torch and heat devices
         assert str(array.device) == device.torch_device
@@ -276,6 +277,13 @@ def lshape(self) -> Tuple[int]:
         """
         return tuple(self.__array.shape)
 
+    @property
+    def lshape_map(self) -> torch.Tensor:
+        """
+        Returns the lshape map. If it has not been created yet, it is created here.
+        """
+        return self.create_lshape_map()
+
     @property
     def real(self) -> DNDarray:
         """
@@ -568,11 +576,20 @@ def cpu(self) -> DNDarray:
         self.__device = devices.cpu
         return self
 
-    def create_lshape_map(self) -> torch.Tensor:
+    def create_lshape_map(self, force_check: bool = True) -> torch.Tensor:
         """
         Generate a 'map' of the lshapes of the data on all processes.
         Units are ``(process rank, lshape)``
+
+        Parameters
+        ----------
+        force_check : bool, optional
+            If ``False`` and the lshape map has already been created, the cached map is
+            reused. Otherwise (the default), the lshape map is recalculated.
         """
+        if not force_check and self.__lshape_map is not None:
+            return self.__lshape_map
+
         lshape_map = torch.zeros(
             (self.comm.size, self.ndim), dtype=torch.int, device=self.device.torch_device
         )
@@ -589,6 +606,7 @@
         )
 
         self.comm.Allreduce(MPI.IN_PLACE, lshape_map, MPI.SUM)
+        self.__lshape_map = lshape_map
         return lshape_map
 
     def __float__(self) -> DNDarray:
@@ -825,16 +843,9 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
                 if rank in actives:
                     key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank]
                     key_stop = counts[rank] if rank != actives[-1] else key_stop - chunk_starts[rank]
-                    if key_step is not None and rank > actives[0]:
-                        offset = (chunk_ends[rank - 1] - og_key_start) % key_step
-                        if key_step > 2 and offset > 0:
-                            key_start += key_step - offset
-                        elif key_step == 2 and offset > 0:
-                            key_start += (chunk_ends[rank - 1] - og_key_start) % key_step
-                    if isinstance(key_start, torch.Tensor):
-                        key_start = key_start.item()
-                    if isinstance(key_stop, torch.Tensor):
-                        key_stop = key_stop.item()
+                    key_start, key_stop = self.__xitem_get_key_start_stop(
+                        rank, actives, key_start, key_stop, key_step, chunk_ends, og_key_start
+                    )
                     key[self.split] = slice(key_start, key_stop, key_step)
                     lout[new_split] = (
                         math.ceil((key_stop - key_start) / key_step)
@@ -994,7 +1005,9 @@ def ravel(self):
         """
         return manipulations.ravel(self)
 
-    def redistribute_(self, lshape_map: torch.Tensor = None, target_map: torch.Tensor = None):
+    def redistribute_(
+        self, lshape_map: Optional[torch.Tensor] = None, target_map: Optional[torch.Tensor] = None
+    ):
         """
         Redistributes the data of the :class:`DNDarray` *along the split axis* to match the given target map.
         This function does not modify the non-split dimensions of the ``DNDarray``.
@@ -1055,9 +1068,6 @@ def redistribute_(self, lshape_map: torch.Tensor = None, target_map: torch.Tenso
                 )
             )
         if target_map is None:  # if no target map is given then it will balance the tensor
-            target_map = torch.zeros(
-                (self.comm.size, len(self.gshape)), dtype=int, device=self.device.torch_device
-            )
             _, _, chk = self.comm.chunk(self.shape, self.split)
             target_map = lshape_map.clone()
             target_map[..., self.split] = 0
@@ -1152,6 +1162,8 @@ def redistribute_(self, lshape_map: torch.Tensor = None, target_map: torch.Tenso
             # (in the case that the second to last processes needs to get data from +1 and -1)
             self.redistribute_(lshape_map=lshape_map, target_map=target_map)
 
+        self.__lshape_map = target_map
+
     def __redistribute_shuffle(
         self,
         snd_pr: Union[int, torch.Tensor],
@@ -1355,6 +1367,18 @@ def __setitem__(
             [0., 1., 0., 0., 0.]])
         """
         key = getattr(key, "copy()", key)
+        try:
+            if value.split != self.split:
+                val_split = int(value.split)
+                sp = self.split
+                warnings.warn(
+                    f"\nvalue.split {val_split} not equal to this DNDarray's split:"
+                    f" {sp}. This may cause errors or unwanted behavior",
+                    category=RuntimeWarning,
+                )
+        except (AttributeError, TypeError):
+            pass
+
         if isinstance(key, DNDarray) and key.ndim == self.ndim:
             # this splits the key into torch.Tensors in each dimension for advanced indexing
             lkey = [slice(None, None, None)] * self.ndim
@@ -1380,92 +1404,150 @@ def __setitem__(
                 kend = key[ell_ind + 1 :]
                 slices = [slice(None)] * (self.ndim - (len(kst) + len(kend)))
                 key = kst + slices + kend
+
+        for c, k in enumerate(key):
+            try:
+                key[c] = k.item()
+            except (AttributeError, ValueError):
+                pass
+        key = tuple(key)
         if not self.is_distributed():
-            self.__setter(key, value)
-        else:
-            # raise RuntimeError("split axis of array and the target value are not equal") removed
-            # this will occur if the local shapes do not match
-            rank = self.comm.rank
-            ends = []
-            for pr in range(self.comm.size):
-                _, _, e = self.comm.chunk(self.shape, self.split, rank=pr)
-                ends.append(e[self.split].stop - e[self.split].start)
-            ends = torch.tensor(ends, device=self.device.torch_device)
-            chunk_ends = ends.cumsum(dim=0)
-            chunk_starts = torch.tensor([0] + chunk_ends.tolist(), device=self.device.torch_device)
-            _, _, chunk_slice = self.comm.chunk(self.shape, self.split)
-            chunk_start = chunk_slice[self.split].start
-            chunk_end = chunk_slice[self.split].stop
-
-            if isinstance(key, tuple):
-                if isinstance(key[self.split], slice):
-                    key = list(key)
-                    key_start = key[self.split].start if key[self.split].start is not None else 0
-                    key_stop = (
-                        key[self.split].stop
-                        if key[self.split].stop is not None
-                        else self.gshape[self.split]
-                    )
-                    if key_stop < 0:
-                        key_stop = self.gshape[self.split] + key[self.split].stop
-                    key_step = key[self.split].step
-                    og_key_start = key_start
-                    st_pr = torch.where(key_start < chunk_ends)[0]
-                    st_pr = st_pr[0] if len(st_pr) > 0 else self.comm.size
-                    sp_pr = torch.where(key_stop >= chunk_starts)[0]
-                    sp_pr = sp_pr[-1] if len(sp_pr) > 0 else 0
-                    actives = list(range(st_pr, sp_pr + 1))
-                    if rank in actives:
-                        key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank]
-                        key_stop = (
-                            ends[rank] if rank != actives[-1] else key_stop - chunk_starts[rank]
-                        )
-                        if key_step is not None and rank > actives[0]:
-                            offset = (chunk_ends[rank - 1] - og_key_start) % key_step
-                            if key_step > 2 and offset > 0:
-                                key_start += key_step - offset
-                            elif key_step == 2 and offset > 0:
-                                key_start += (chunk_ends[rank - 1] - og_key_start) % key_step
-                        if isinstance(key_start, torch.Tensor):
-                            key_start = key_start.item()
-                        if isinstance(key_stop, torch.Tensor):
-                            key_stop = key_stop.item()
-                        key[self.split] = slice(key_start, key_stop, key_step)
-                        # todo: need to slice the values to be the right size...
-                        if isinstance(value, (torch.Tensor, type(self))):
-                            value_slice = [slice(None, None, None)] * value.ndim
-                            step2 = key_step if key_step is not None else 1
-                            key_start = chunk_starts[rank] - og_key_start
-                            key_stop = key_start + key_stop
-                            slice_loc = (
-                                value.ndim - 1 if self.split > value.ndim - 1 else self.split
-                            )
-                            value_slice[slice_loc] = slice(
-                                key_start.item(), math.ceil(torch.true_divide(key_stop, step2)), 1
-                            )
-                            self.__setter(tuple(key), value[tuple(value_slice)])
-                        else:
-                            self.__setter(tuple(key), value)
-
-                elif isinstance(key[self.split], torch.Tensor):
-                    key = list(key)
-                    key[self.split] -= chunk_start
-                    self.__setter(tuple(key), value)
-
-                elif key[self.split] in range(chunk_start, chunk_end):
-                    key = list(key)
-                    key[self.split] = key[self.split] - chunk_start
-                    self.__setter(tuple(key), value)
-
-                elif key[self.split] < 0:
-                    key = list(key)
-                    if self.gshape[self.split] + key[self.split] in range(chunk_start, chunk_end):
-                        key[self.split] = key[self.split] + self.shape[self.split] - chunk_start
-                        self.__setter(tuple(key), value)
-            else:
-                self.__setter(key, value)
+            return self.__setter(key, value)  # returns None
+
+        # raise RuntimeError("split axis of array and the target value are not equal") removed
+        # this will occur if the local shapes do not match
+        rank = self.comm.rank
+        ends = []
+        for pr in range(self.comm.size):
+            _, _, e = self.comm.chunk(self.shape, self.split, rank=pr)
+            ends.append(e[self.split].stop - e[self.split].start)
+        ends = torch.tensor(ends, device=self.device.torch_device)
+        chunk_ends = ends.cumsum(dim=0)
+        chunk_starts = torch.tensor([0] + chunk_ends.tolist(), device=self.device.torch_device)
+        _, _, chunk_slice = self.comm.chunk(self.shape, self.split)
+        chunk_start = chunk_slice[self.split].start
+        chunk_end = chunk_slice[self.split].stop
+
+        self_proxy = torch.ones((1,)).as_strided(self.gshape, [0] * self.ndim)
+
+        # if the value is a DNDarray, the divisions need to be balanced:
+        #   this means that we need to know how much data is where for both DNDarrays
+        #   if the value data is not in the right place, then it will need to be moved
+
+        if isinstance(key[self.split], slice):
+            key = list(key)
+            key_start = key[self.split].start if key[self.split].start is not None else 0
+            key_stop = (
+                key[self.split].stop
+                if key[self.split].stop is not None
+                else self.gshape[self.split]
+            )
+            if key_stop < 0:
+                key_stop = self.gshape[self.split] + key[self.split].stop
+            key_step = key[self.split].step
+            og_key_start = key_start
+            st_pr = torch.where(key_start < chunk_ends)[0]
+            st_pr = st_pr[0] if len(st_pr) > 0 else self.comm.size
+            sp_pr = torch.where(key_stop >= chunk_starts)[0]
+            sp_pr = sp_pr[-1] if len(sp_pr) > 0 else 0
+            actives = list(range(st_pr, sp_pr + 1))
+
+            if (
+                isinstance(value, type(self))
+                and value.split is not None
+                and value.shape[self.split] != self.shape[self.split]
+            ):
+                # setting elements in self with a DNDarray which is not the same size in the
+                # split dimension
+                local_keys = []
+                # below is used if the target needs to be reshaped
+                target_reshape_map = torch.zeros(
+                    (self.comm.size, self.ndim), dtype=torch.int, device=self.device.torch_device
+                )
+                for r in range(self.comm.size):
+                    if r not in actives:
+                        loc_key = key.copy()
+                        loc_key[self.split] = slice(0, 0, 0)
+                    else:
+                        key_start_l = 0 if r != actives[0] else key_start - chunk_starts[r]
+                        key_stop_l = ends[r] if r != actives[-1] else key_stop - chunk_starts[r]
+                        key_start_l, key_stop_l = self.__xitem_get_key_start_stop(
+                            r, actives, key_start_l, key_stop_l, key_step, chunk_ends, og_key_start
+                        )
+                        loc_key = key.copy()
+                        loc_key[self.split] = slice(key_start_l, key_stop_l, key_step)
+
+                    gout_full = torch.tensor(
+                        self_proxy[loc_key].shape, device=self.device.torch_device
+                    )
+                    target_reshape_map[r] = gout_full
+                    local_keys.append(loc_key)
+
+                key = local_keys[rank]
+                value = value.redistribute(target_map=target_reshape_map)
+
+                if rank not in actives:
+                    return  # non-active ranks can exit here
+
+                chunk_starts_v = target_reshape_map[:, self.split]
+                value_slice = [slice(None, None, None)] * value.ndim
+                step2 = key_step if key_step is not None else 1
+                key_start = (chunk_starts_v[rank] - og_key_start).item()
+
+                if key_start < 0:
+                    key_start = 0
+                key_stop = key_start + key_stop
+                slice_loc = value.ndim - 1 if self.split > value.ndim - 1 else self.split
+                value_slice[slice_loc] = slice(
+                    key_start, math.ceil(torch.true_divide(key_stop, step2)), 1
+                )
+
+                self.__setter(tuple(key), value.larray)
+                return
+
+            # if rank in actives:
+            if rank not in actives:
+                return  # non-active ranks can exit here
+            key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank]
+            key_stop = ends[rank] if rank != actives[-1] else key_stop - chunk_starts[rank]
+            key_start, key_stop = self.__xitem_get_key_start_stop(
+                rank, actives, key_start, key_stop, key_step, chunk_ends, og_key_start
+            )
+            key[self.split] = slice(key_start, key_stop, key_step)
+
+            # todo: need to slice the values to be the right size...
+            if isinstance(value, (torch.Tensor, type(self))):
+                # if it is a torch tensor, it is assumed to exist on all processes
+                value_slice = [slice(None, None, None)] * value.ndim
+                step2 = key_step if key_step is not None else 1
+                key_start = (chunk_starts[rank] - og_key_start).item()
+                if key_start < 0:
+                    key_start = 0
+                key_stop = key_start + key_stop
+                slice_loc = value.ndim - 1 if self.split > value.ndim - 1 else self.split
+                value_slice[slice_loc] = slice(
+                    key_start, math.ceil(torch.true_divide(key_stop, step2)), 1
+                )
+                self.__setter(tuple(key), value[tuple(value_slice)])
+            else:
+                self.__setter(tuple(key), value)
+        elif isinstance(key[self.split], torch.Tensor):
+            key = list(key)
+            key[self.split] -= chunk_start
+            self.__setter(tuple(key), value)
+
+        elif key[self.split] in range(chunk_start, chunk_end):
+            key = list(key)
+            key[self.split] = key[self.split] - chunk_start
+            self.__setter(tuple(key), value)
+
+        elif key[self.split] < 0:
+            key = list(key)
+            if self.gshape[self.split] + key[self.split] in range(chunk_start, chunk_end):
+                key[self.split] = key[self.split] + self.shape[self.split] - chunk_start
+                self.__setter(tuple(key), value)
 
     def __setter(
         self,
@@ -1530,6 +1612,30 @@ def tolist(self, keepsplit: bool = False) -> List:
 
         return self.__array.tolist()
 
+    @staticmethod
+    def __xitem_get_key_start_stop(
+        rank: int,
+        actives: list,
+        key_st: int,
+        key_sp: int,
+        step: int,
+        ends: torch.Tensor,
+        og_key_st: int,
+    ) -> Tuple[int, int]:
+        # basic logic for adjusting the start and stop of a key for
+        # setitem and getitem
+        if step is not None and rank > actives[0]:
+            offset = (ends[rank - 1] - og_key_st) % step
+            if step > 2 and offset > 0:
+                key_st += step - offset
+            elif step == 2 and offset > 0:
+                key_st += (ends[rank - 1] - og_key_st) % step
+        if isinstance(key_st, torch.Tensor):
+            key_st = key_st.item()
+        if isinstance(key_sp, torch.Tensor):
+            key_sp = key_sp.item()
+        return key_st, key_sp
+
 
 # HeAT imports at the end to break cyclic dependencies
 from . import complex_math
diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py
index 614ae1692a..ef0e3f305c 100644
--- a/heat/core/manipulations.py
+++ b/heat/core/manipulations.py
@@ -38,6 +38,7 @@
     "hstack",
     "pad",
     "ravel",
+    "redistribute",
     "repeat",
     "reshape",
     "resplit",
@@ -1434,6 +1435,63 @@ def ravel(a: DNDarray) -> DNDarray:
     return result
 
 
+def redistribute(
+    arr: DNDarray, lshape_map: torch.Tensor = None, target_map: torch.Tensor = None
+) -> DNDarray:
+    """
+    Redistributes the data of the :class:`DNDarray` *along the split axis* to match the given target map.
+    This function does not modify the non-split dimensions of the ``DNDarray``.
+    This is an abstraction and extension of the balance function.
+
+    Parameters
+    ----------
+    arr : DNDarray
+        DNDarray to redistribute
+    lshape_map : torch.Tensor, optional
+        The current lshape of each process.
+        Units are ``[rank, lshape]``.
+    target_map : torch.Tensor, optional
+        The desired distribution across the processes.
+        Units are ``[rank, target lshape]``.
+        Note: only the values of the target map along the split axis matter;
+        the values off this axis are only there to mimic the shape of the ``lshape_map``.
+
+    Examples
+    --------
+    >>> st = ht.ones((50, 81, 67), split=2)
+    >>> target_map = torch.zeros((st.comm.size, 3), dtype=torch.int)
+    >>> target_map[0, 2] = 67
+    >>> print(target_map)
+    [0/2] tensor([[ 0,  0, 67],
+    [0/2]         [ 0,  0,  0],
+    [0/2]         [ 0,  0,  0]], dtype=torch.int32)
+    [1/2] tensor([[ 0,  0, 67],
+    [1/2]         [ 0,  0,  0],
+    [1/2]         [ 0,  0,  0]], dtype=torch.int32)
+    [2/2] tensor([[ 0,  0, 67],
+    [2/2]         [ 0,  0,  0],
+    [2/2]         [ 0,  0,  0]], dtype=torch.int32)
+    >>> print(st.lshape)
+    [0/2] (50, 81, 23)
+    [1/2] (50, 81, 22)
+    [2/2] (50, 81, 22)
+    >>> st = ht.redistribute(st, target_map=target_map)
+    >>> print(st.lshape)
+    [0/2] (50, 81, 67)
+    [1/2] (50, 81, 0)
+    [2/2] (50, 81, 0)
+    """
+    arr2 = arr.copy()
+    arr2.redistribute_(lshape_map=lshape_map, target_map=target_map)
+    return arr2
+
+
+DNDarray.redistribute = lambda arr, lshape_map=None, target_map=None: redistribute(
+    arr, lshape_map, target_map
+)
+DNDarray.redistribute.__doc__ = redistribute.__doc__
+
+
 def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarray:
     """
     Creates a new `DNDarray` by repeating elements of array `a`. The output has
diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py
index 87f7f91be2..92e1182ec2 100644
--- a/heat/core/tests/test_dndarray.py
+++ b/heat/core/tests/test_dndarray.py
@@ -245,7 +245,8 @@ def test_astype(self):
     def test_balance_and_lshape_map(self):
         data = ht.zeros((70, 20), split=0)
         data = data[:50]
-        lshape_map = data.create_lshape_map()
+        data.lshape_map
+        lshape_map = data.create_lshape_map(force_check=False)  # tests the property
         self.assertEqual(sum(lshape_map[..., 0]), 50)
         if sum(data.lshape) == 0:
             self.assertTrue(all(lshape_map[data.comm.rank] == 0))
@@ -986,6 +987,27 @@ def test_rshift(self):
         self.assertTrue(res == 0)
 
     def test_setitem_getitem(self):
+        # tests for bug #825
+        a = ht.ones((102, 102), split=0)
+        setting = ht.zeros((100, 100), split=0)
+        a[1:-1, 1:-1] = setting
+        self.assertTrue(ht.all(a[1:-1, 1:-1] == 0))
+
+        a = ht.ones((102, 102), split=1)
+        setting = ht.zeros((30, 100), split=1)
+        a[-30:, 1:-1] = setting
+        self.assertTrue(ht.all(a[-30:, 1:-1] == 0))
+
+        a = ht.ones((102, 102), split=1)
+        setting = ht.zeros((100, 100), split=1)
+        a[1:-1, 1:-1] = setting
+        self.assertTrue(ht.all(a[1:-1, 1:-1] == 0))
+
+        a = ht.ones((102, 102), split=1)
+        setting = ht.zeros((100, 20), split=1)
+        a[1:-1, :20] = setting
+        self.assertTrue(ht.all(a[1:-1, :20] == 0))
+
         # tests for bug 730:
         a = ht.ones((10, 25, 30), split=1)
         if a.comm.size > 1:
@@ -1068,7 +1090,7 @@ def test_setitem_getitem(self):
 
         # slice in 1st dim across 1 node (2nd) w/ singular second dim
         c = ht.zeros((13, 5), split=0)
-        c[8:12, 1] = 1
+        c[8:12, ht.array(1)] = 1
         b = c[8:12, np.int64(1)]
         self.assertTrue((b == 1).all())
         self.assertEqual(b.gshape, (4,))
@@ -1311,6 +1333,11 @@ def test_setitem_getitem(self):
             a[..., ...]
         with self.assertRaises(ValueError):
             a[..., ...] = 1
+        if a.comm.size > 1:
+            with self.assertRaises(ValueError):
+                x = ht.ones((10, 10), split=0)
+                setting = ht.zeros((8, 8), split=1)
+                x[1:-1, 1:-1] = setting
 
     def test_size_gnumel(self):
         a = ht.zeros((10, 10, 10), split=None)
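
Usage sketch (illustrative only, not part of the patch): the tests added above exercise the repaired `__setitem__` path, in which the value `DNDarray` has a different global shape along the split dimension and must be redistributed to match the target chunks before assignment. A minimal reproduction of the fixed behaviour, assuming Heat with this patch installed and an MPI launch such as `mpirun -np 3 python example.py`:

    import heat as ht
    import torch

    # 'a' is split along rows; 'setting' differs in the split dimension (100 vs 102),
    # so its chunks are re-aligned internally before the local assignment happens.
    a = ht.ones((102, 102), split=0)
    setting = ht.zeros((100, 100), split=0)
    a[1:-1, 1:-1] = setting
    print(ht.all(a[1:-1, 1:-1] == 0))  # prints a DNDarray containing True on every process

    # The PR also exposes the out-of-place ht.redistribute / DNDarray.redistribute,
    # mirroring the docstring example above: put all 102 rows on rank 0.
    target_map = torch.zeros((a.comm.size, 2), dtype=torch.int)
    target_map[0, 0] = 102
    b = a.redistribute(target_map=target_map)
    print(b.lshape)  # (102, 102) on rank 0, (0, 102) elsewhere

The first block mirrors the new `test_setitem_getitem` cases for bug #825; the second mirrors the `redistribute` docstring example, adapted to a 2-D array.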