From 639cd572116770508fe9f8ea19cf4feac0784086 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 29 Jun 2021 15:03:32 +0200 Subject: [PATCH 01/18] setitem can now set values with DNDarrays which are not the same size in the split dimension, lshape_map property added --- heat/core/dndarray.py | 307 ++++++++++++++++++++++++++++++------------ 1 file changed, 222 insertions(+), 85 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 631fd933cd..ae8b5de18d 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -9,7 +9,7 @@ from inspect import stack from mpi4py import MPI from pathlib import Path -from typing import List, Union, Tuple, TypeVar +from typing import List, Union, Tuple, TypeVar, Optional warnings.simplefilter("always", ResourceWarning) @@ -80,6 +80,7 @@ def __init__( self.__ishalo = False self.__halo_next = None self.__halo_prev = None + self.__lshape_map = None # check for inconsistencies between torch and heat devices assert str(array.device) == device.torch_device @@ -276,6 +277,13 @@ def lshape(self) -> Tuple[int]: """ return tuple(self.__array.shape) + @property + def lshape_map(self) -> torch.Tensor: + """ + Returns the lshape map. if it has been previously created then it will be created here + """ + return self.create_lshape_map() + @property def real(self) -> DNDarray: """ @@ -568,11 +576,21 @@ def cpu(self) -> DNDarray: self.__device = devices.cpu return self - def create_lshape_map(self) -> torch.Tensor: + def create_lshape_map(self, recreate: Optional[bool] = True) -> torch.Tensor: """ Generate a 'map' of the lshapes of the data on all processes. Units are ``(process rank, lshape)`` + + Parameters + ---------- + recreate : bool, optional + if False (default) and the lshape map has already been created, use the previous + result. 
Otherwise, create the lshape_map + Default: False """ + if not recreate and self.__lshape_map is not None: + return self.__lshape_map + lshape_map = torch.zeros( (self.comm.size, self.ndim), dtype=torch.int, device=self.device.torch_device ) @@ -589,6 +607,7 @@ def create_lshape_map(self) -> torch.Tensor: ) self.comm.Allreduce(MPI.IN_PLACE, lshape_map, MPI.SUM) + self.__lshape_map = lshape_map return lshape_map def __float__(self) -> DNDarray: @@ -1055,9 +1074,6 @@ def redistribute_(self, lshape_map: torch.Tensor = None, target_map: torch.Tenso ) ) if target_map is None: # if no target map is given then it will balance the tensor - target_map = torch.zeros( - (self.comm.size, len(self.gshape)), dtype=int, device=self.device.torch_device - ) _, _, chk = self.comm.chunk(self.shape, self.split) target_map = lshape_map.clone() target_map[..., self.split] = 0 @@ -1152,6 +1168,8 @@ def redistribute_(self, lshape_map: torch.Tensor = None, target_map: torch.Tenso # (in the case that the second to last processes needs to get data from +1 and -1) self.redistribute_(lshape_map=lshape_map, target_map=target_map) + self.__lshape_map = target_map + def __redistribute_shuffle( self, snd_pr: Union[int, torch.Tensor], @@ -1383,89 +1401,208 @@ def __setitem__( key = tuple(key) if not self.is_distributed(): - self.__setter(key, value) - else: - # raise RuntimeError("split axis of array and the target value are not equal") removed - # this will occur if the local shapes do not match - rank = self.comm.rank - ends = [] - for pr in range(self.comm.size): - _, _, e = self.comm.chunk(self.shape, self.split, rank=pr) - ends.append(e[self.split].stop - e[self.split].start) - ends = torch.tensor(ends, device=self.device.torch_device) - chunk_ends = ends.cumsum(dim=0) - chunk_starts = torch.tensor([0] + chunk_ends.tolist(), device=self.device.torch_device) - _, _, chunk_slice = self.comm.chunk(self.shape, self.split) - chunk_start = chunk_slice[self.split].start - chunk_end = chunk_slice[self.split].stop - - if isinstance(key, tuple): - if isinstance(key[self.split], slice): - key = list(key) - key_start = key[self.split].start if key[self.split].start is not None else 0 - key_stop = ( - key[self.split].stop - if key[self.split].stop is not None - else self.gshape[self.split] - ) - if key_stop < 0: - key_stop = self.gshape[self.split] + key[self.split].stop - key_step = key[self.split].step - og_key_start = key_start - st_pr = torch.where(key_start < chunk_ends)[0] - st_pr = st_pr[0] if len(st_pr) > 0 else self.comm.size - sp_pr = torch.where(key_stop >= chunk_starts)[0] - sp_pr = sp_pr[-1] if len(sp_pr) > 0 else 0 - actives = list(range(st_pr, sp_pr + 1)) - if rank in actives: - key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank] - key_stop = ( - ends[rank] if rank != actives[-1] else key_stop - chunk_starts[rank] - ) - if key_step is not None and rank > actives[0]: - offset = (chunk_ends[rank - 1] - og_key_start) % key_step + return self.__setter(key, value) # returns None + + # raise RuntimeError("split axis of array and the target value are not equal") removed + # this will occur if the local shapes do not match + rank = self.comm.rank + ends = [] + for pr in range(self.comm.size): + _, _, e = self.comm.chunk(self.shape, self.split, rank=pr) + ends.append(e[self.split].stop - e[self.split].start) + ends = torch.tensor(ends, device=self.device.torch_device) + chunk_ends = ends.cumsum(dim=0) + chunk_starts = torch.tensor([0] + chunk_ends.tolist(), device=self.device.torch_device) + _, _, 
chunk_slice = self.comm.chunk(self.shape, self.split) + chunk_start = chunk_slice[self.split].start + chunk_end = chunk_slice[self.split].stop + + if not isinstance(key, tuple): + return self.__setter(key, value) # returns None + + # if the value is a DNDarray, the divisions need to be balanced: + # this means that we need to know how much data is where for both DNDarrays + # if the value data is not in the right place, then it will need to be moved + + # if isinstance(key, tuple): + if isinstance(key[self.split], slice): + key = list(key) + key_start = key[self.split].start if key[self.split].start is not None else 0 + key_stop = ( + key[self.split].stop + if key[self.split].stop is not None + else self.gshape[self.split] + ) + if key_stop < 0: + key_stop = self.gshape[self.split] + key[self.split].stop + key_step = key[self.split].step + og_key_start = key_start + st_pr = torch.where(key_start < chunk_ends)[0] + st_pr = st_pr[0] if len(st_pr) > 0 else self.comm.size + sp_pr = torch.where(key_stop >= chunk_starts)[0] + sp_pr = sp_pr[-1] if len(sp_pr) > 0 else 0 + actives = list(range(st_pr, sp_pr + 1)) + + if ( + isinstance(value, type(self)) + and value.split is not None + and value.shape[self.split] != self.shape[self.split] + ): + # setting elements in self with a DNDarray which is not the same size in the + # split dimension + local_keys = [] + # below is used if the target needs to be reshaped + target_reshape_map = torch.zeros( + (self.comm.size, self.ndim), dtype=torch.int, device=self.device.torch_device + ) + self_proxy = torch.ones((1,)).as_strided(self.gshape, [0] * self.ndim) + for r in range(self.comm.size): + if r not in actives: + loc_key = key.copy() + loc_key[self.split] = slice(0, 0, 0) + else: + key_start_l = 0 if r != actives[0] else key_start - chunk_starts[r] + key_stop_l = ends[r] if r != actives[-1] else key_stop - chunk_starts[r] + if key_step is not None and r > actives[0]: + offset = (chunk_ends[r - 1] - og_key_start) % key_step if key_step > 2 and offset > 0: - key_start += key_step - offset + key_start_l += key_step - offset elif key_step == 2 and offset > 0: - key_start += (chunk_ends[rank - 1] - og_key_start) % key_step - if isinstance(key_start, torch.Tensor): - key_start = key_start.item() - if isinstance(key_stop, torch.Tensor): - key_stop = key_stop.item() - key[self.split] = slice(key_start, key_stop, key_step) - # todo: need to slice the values to be the right size... 
- if isinstance(value, (torch.Tensor, type(self))): - value_slice = [slice(None, None, None)] * value.ndim - step2 = key_step if key_step is not None else 1 - key_start = chunk_starts[rank] - og_key_start - key_stop = key_start + key_stop - slice_loc = ( - value.ndim - 1 if self.split > value.ndim - 1 else self.split - ) - value_slice[slice_loc] = slice( - key_start.item(), math.ceil(torch.true_divide(key_stop, step2)), 1 - ) - self.__setter(tuple(key), value[tuple(value_slice)]) - else: - self.__setter(tuple(key), value) - - elif isinstance(key[self.split], torch.Tensor): - key = list(key) - key[self.split] -= chunk_start - self.__setter(tuple(key), value) - - elif key[self.split] in range(chunk_start, chunk_end): - key = list(key) - key[self.split] = key[self.split] - chunk_start - self.__setter(tuple(key), value) - - elif key[self.split] < 0: - key = list(key) - if self.gshape[self.split] + key[self.split] in range(chunk_start, chunk_end): - key[self.split] = key[self.split] + self.shape[self.split] - chunk_start - self.__setter(tuple(key), value) + key_start_l += (chunk_ends[r - 1] - og_key_start) % key_step + if isinstance(key_start_l, torch.Tensor): + key_start_l = key_start_l.item() + if isinstance(key_stop_l, torch.Tensor): + key_stop_l = key_stop_l.item() + loc_key = key.copy() + loc_key[self.split] = slice(key_start_l, key_stop_l, key_step) + + gout_full = torch.tensor( + self_proxy[loc_key].shape, device=self.device.torch_device + ) + target_reshape_map[r] = gout_full + local_keys.append(loc_key) + + key = local_keys[rank] + value = value.redistribute(target_map=target_reshape_map) + + if rank not in actives: + return # non-active ranks can exit here + + chunk_starts_v = target_reshape_map[:, self.split] + value_slice = [slice(None, None, None)] * value.ndim + step2 = key_step if key_step is not None else 1 + key_start = (chunk_starts_v[rank] - og_key_start).item() + # print(key_start) + if key_start < 0: + key_start = 0 + key_stop = key_start + key_stop + slice_loc = value.ndim - 1 if self.split > value.ndim - 1 else self.split + value_slice[slice_loc] = slice( + key_start, math.ceil(torch.true_divide(key_stop, step2)), 1 + ) + + self.__setter(tuple(key), value.larray) + return + + # if rank in actives: + if rank not in actives: + return # non-active ranks can exit here + key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank] + key_stop = ends[rank] if rank != actives[-1] else key_stop - chunk_starts[rank] + if key_step is not None and rank > actives[0]: + offset = (chunk_ends[rank - 1] - og_key_start) % key_step + if key_step > 2 and offset > 0: + key_start += key_step - offset + elif key_step == 2 and offset > 0: + key_start += (chunk_ends[rank - 1] - og_key_start) % key_step + if isinstance(key_start, torch.Tensor): + key_start = key_start.item() + if isinstance(key_stop, torch.Tensor): + key_stop = key_stop.item() + key[self.split] = slice(key_start, key_stop, key_step) + # key = local_keys[rank] + # print(key, chunk_starts) + # todo: need to slice the values to be the right size... 
+ if isinstance(value, (torch.Tensor, type(self))): + # if its a torch tensor, it is assumed to exist on all processes + value_slice = [slice(None, None, None)] * value.ndim + step2 = key_step if key_step is not None else 1 + key_start = (chunk_starts[rank] - og_key_start).item() + if key_start < 0: + key_start = 0 + key_stop = key_start + key_stop + slice_loc = value.ndim - 1 if self.split > value.ndim - 1 else self.split + value_slice[slice_loc] = slice( + key_start, math.ceil(torch.true_divide(key_stop, step2)), 1 + ) + # print(key, value_slice) + self.__setter(tuple(key), value[tuple(value_slice)]) + # if isinstance(value, type(self)): + # # need to make sure that the data is in the right place here. + # # get expected shape (key is a tuple): + # # exp_shape = [] + # # ints = 0 + # # for c, k in enumerate(key): + # # if isinstance(k, int): + # # ints += 1 + # # elif isinstance(k, slice): + # # stp = k.stop if k.stop > 0 else self.gshape[c] + k.stop + # # start = k.start if k.start > 0 else self.gshape[c] + k.start + # # exp_shape.append((stp - start) // k.step) + # # if len(exp_shape) - ints > self.ndim: + # # for r in range(len(exp_shape) - ints - 1, self.ndim): + # # exp_shape.append(self.gshape[r]) + # # # todo: above needs to be modified to only do things for the split dim!! + # + # # next objective is finding the location of all of elements in self to set + # # only need to do this in the split dimension + # + # # compare the key dims/splits to the set/splits + # self_proxy = torch.ones((1,)).as_strided(self.gshape, [0] * self.ndim) + # gout_full = torch.tensor(self_proxy[key].shape, device=self.device.torch_device) + # target_reshape_map = torch.zeros( + # (self.comm.size, self.ndim), + # dtype=torch.int, + # device=self.device.torch_device, + # ) + # target_reshape_map[self.comm.rank] = gout_full + # # self.comm.Allreduce(MPI.IN_PLACE, target_reshape_map, MPI.SUM) + # print(target_reshape_map) + # # value = value.redistribute(target_map=target_reshape_map) + # # print('h', value.lshape, key) + # # # print('h', gout_full, key) + # # + # # value_slice = [slice(None, None, None)] * value.ndim + # # step2 = key_step if key_step is not None else 1 + # # key_start = (chunk_starts[rank] - og_key_start).item() + # # if key_start < 0: + # # key_start = 0 + # # key_stop = key_start + key_stop + # # slice_loc = ( + # # value.ndim - 1 if self.split > value.ndim - 1 else self.split + # # ) + # # value_slice[slice_loc] = slice( + # # key_start, math.ceil(torch.true_divide(key_stop, step2)), 1 + # # ) + # # print('v', value_slice) + # # self.__setter(tuple(key), value[tuple(value_slice)]) else: - self.__setter(key, value) + self.__setter(tuple(key), value) + elif isinstance(key[self.split], torch.Tensor): + key = list(key) + key[self.split] -= chunk_start + self.__setter(tuple(key), value) + + elif key[self.split] in range(chunk_start, chunk_end): + key = list(key) + key[self.split] = key[self.split] - chunk_start + self.__setter(tuple(key), value) + + elif key[self.split] < 0: + key = list(key) + if self.gshape[self.split] + key[self.split] in range(chunk_start, chunk_end): + key[self.split] = key[self.split] + self.shape[self.split] - chunk_start + self.__setter(tuple(key), value) def __setter( self, From d1a0e551692027a22f8803e2beabc2f412269531 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 29 Jun 2021 15:04:01 +0200 Subject: [PATCH 02/18] added new test cases (simple) --- heat/core/tests/test_dndarray.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git 
a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 684b220584..08351c5b89 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -986,6 +986,17 @@ def test_rshift(self): self.assertTrue(res == 0) def test_setitem_getitem(self): + # tests for bug #825 + a = ht.ones((102, 102), split=0) + setting = ht.zeros((100, 100), split=0) + a[1:-1, 1:-1] = setting + self.assertTrue(ht.all(a[1:-1, 1:-1] == 0)) + + a = ht.ones((102, 102), split=1) + setting = ht.zeros((100, 100), split=1) + a[1:-1, 1:-1] = setting + self.assertTrue(ht.all(a[1:-1, 1:-1] == 0)) + # tests for bug 730: a = ht.ones((10, 25, 30), split=1) if a.comm.size > 1: From 28c268fed7c469d123dd54d5156c433c9239639f Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 29 Jun 2021 15:05:03 +0200 Subject: [PATCH 03/18] added oop redistribute to manipulations --- heat/core/manipulations.py | 58 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 614ae1692a..ef0e3f305c 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -38,6 +38,7 @@ "hstack", "pad", "ravel", + "redistribute", "repeat", "reshape", "resplit", @@ -1434,6 +1435,63 @@ def ravel(a: DNDarray) -> DNDarray: return result +def redistribute( + arr: DNDarray, lshape_map: torch.Tensor = None, target_map: torch.Tensor = None +) -> DNDarray: + """ + Redistributes the data of the :class:`DNDarray` *along the split axis* to match the given target map. + This function does not modify the non-split dimensions of the ``DNDarray``. + This is an abstraction and extension of the balance function. + + Parameters + ---------- + arr: DNDarray + DNDarray to redistribute + lshape_map : torch.Tensor, optional + The current lshape of processes. + Units are ``[rank, lshape]``. + target_map : torch.Tensor, optional + The desired distribution across the processes. + Units are ``[rank, target lshape]``. + Note: the only important parts of the target map are the values along the split axis, + values which are not along this axis are there to mimic the shape of the ``lshape_map``. + + Examples + -------- + >>> st = ht.ones((50, 81, 67), split=2) + >>> target_map = torch.zeros((st.comm.size, 3), dtype=torch.int) + >>> target_map[0, 2] = 67 + >>> print(target_map) + [0/2] tensor([[ 0, 0, 67], + [0/2] [ 0, 0, 0], + [0/2] [ 0, 0, 0]], dtype=torch.int32) + [1/2] tensor([[ 0, 0, 67], + [1/2] [ 0, 0, 0], + [1/2] [ 0, 0, 0]], dtype=torch.int32) + [2/2] tensor([[ 0, 0, 67], + [2/2] [ 0, 0, 0], + [2/2] [ 0, 0, 0]], dtype=torch.int32) + >>> print(st.lshape) + [0/2] (50, 81, 23) + [1/2] (50, 81, 22) + [2/2] (50, 81, 22) + >>> ht.redistribute_(st, target_map=target_map) + >>> print(st.lshape) + [0/2] (50, 81, 67) + [1/2] (50, 81, 0) + [2/2] (50, 81, 0) + """ + arr2 = arr.copy() + arr2.redistribute_(lshape_map=lshape_map, target_map=target_map) + return arr2 + + +DNDarray.redistribute = lambda arr, lshape_map=None, target_map=None: redistribute( + arr, lshape_map, target_map +) +DNDarray.redistribute.__doc__ = redistribute.__doc__ + + def repeat(a: Iterable, repeats: Iterable, axis: Optional[int] = None) -> DNDarray: """ Creates a new `DNDarray` by repeating elements of array `a`. 
The output has From 61266c1c1aee00ca91afee2b45991bc9e0e768e9 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 29 Jun 2021 15:08:23 +0200 Subject: [PATCH 04/18] changelog update --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c26939204..a7fb2cfcda 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,7 @@ Example on 2 processes: - [#790](https://github.com/helmholtz-analytics/heat/pull/790) catch incorrect device after `bcast` in `DNDarray.__getitem__` - [#811](https://github.com/helmholtz-analytics/heat/pull/811) Fixed memory leak in `DNDarray.larray` - [#821](https://github.com/helmholtz-analytics/heat/pull/821) Fixed `__getitem__` handling of distributed `DNDarray` key element +- [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension ### Exponential - [#812](https://github.com/helmholtz-analytics/heat/pull/712) New feature: `logaddexp`, `logaddexp2` From 174c79259f47bddb30510655808a6d614429d377 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 29 Jun 2021 15:25:35 +0200 Subject: [PATCH 05/18] added more test cases to increase coveraged and removed some dead code --- heat/core/dndarray.py | 53 +------------------------------- heat/core/tests/test_dndarray.py | 10 ++++++ 2 files changed, 11 insertions(+), 52 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index ae8b5de18d..8a82d96564 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1520,8 +1520,7 @@ def __setitem__( if isinstance(key_stop, torch.Tensor): key_stop = key_stop.item() key[self.split] = slice(key_start, key_stop, key_step) - # key = local_keys[rank] - # print(key, chunk_starts) + # todo: need to slice the values to be the right size... if isinstance(value, (torch.Tensor, type(self))): # if its a torch tensor, it is assumed to exist on all processes @@ -1535,57 +1534,7 @@ def __setitem__( value_slice[slice_loc] = slice( key_start, math.ceil(torch.true_divide(key_stop, step2)), 1 ) - # print(key, value_slice) self.__setter(tuple(key), value[tuple(value_slice)]) - # if isinstance(value, type(self)): - # # need to make sure that the data is in the right place here. - # # get expected shape (key is a tuple): - # # exp_shape = [] - # # ints = 0 - # # for c, k in enumerate(key): - # # if isinstance(k, int): - # # ints += 1 - # # elif isinstance(k, slice): - # # stp = k.stop if k.stop > 0 else self.gshape[c] + k.stop - # # start = k.start if k.start > 0 else self.gshape[c] + k.start - # # exp_shape.append((stp - start) // k.step) - # # if len(exp_shape) - ints > self.ndim: - # # for r in range(len(exp_shape) - ints - 1, self.ndim): - # # exp_shape.append(self.gshape[r]) - # # # todo: above needs to be modified to only do things for the split dim!! 
- # - # # next objective is finding the location of all of elements in self to set - # # only need to do this in the split dimension - # - # # compare the key dims/splits to the set/splits - # self_proxy = torch.ones((1,)).as_strided(self.gshape, [0] * self.ndim) - # gout_full = torch.tensor(self_proxy[key].shape, device=self.device.torch_device) - # target_reshape_map = torch.zeros( - # (self.comm.size, self.ndim), - # dtype=torch.int, - # device=self.device.torch_device, - # ) - # target_reshape_map[self.comm.rank] = gout_full - # # self.comm.Allreduce(MPI.IN_PLACE, target_reshape_map, MPI.SUM) - # print(target_reshape_map) - # # value = value.redistribute(target_map=target_reshape_map) - # # print('h', value.lshape, key) - # # # print('h', gout_full, key) - # # - # # value_slice = [slice(None, None, None)] * value.ndim - # # step2 = key_step if key_step is not None else 1 - # # key_start = (chunk_starts[rank] - og_key_start).item() - # # if key_start < 0: - # # key_start = 0 - # # key_stop = key_start + key_stop - # # slice_loc = ( - # # value.ndim - 1 if self.split > value.ndim - 1 else self.split - # # ) - # # value_slice[slice_loc] = slice( - # # key_start, math.ceil(torch.true_divide(key_stop, step2)), 1 - # # ) - # # print('v', value_slice) - # # self.__setter(tuple(key), value[tuple(value_slice)]) else: self.__setter(tuple(key), value) elif isinstance(key[self.split], torch.Tensor): diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 08351c5b89..b05b98f42c 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -992,11 +992,21 @@ def test_setitem_getitem(self): a[1:-1, 1:-1] = setting self.assertTrue(ht.all(a[1:-1, 1:-1] == 0)) + a = ht.ones((102, 102), split=1) + setting = ht.zeros((30, 100), split=1) + a[-30:, 1:-1] = setting + self.assertTrue(ht.all(a[-30:, 1:-1] == 0)) + a = ht.ones((102, 102), split=1) setting = ht.zeros((100, 100), split=1) a[1:-1, 1:-1] = setting self.assertTrue(ht.all(a[1:-1, 1:-1] == 0)) + a = ht.ones((102, 102), split=1) + setting = ht.zeros((100, 20), split=1) + a[1:-1, :20] = setting + self.assertTrue(ht.all(a[1:-1, :20] == 0)) + # tests for bug 730: a = ht.ones((10, 25, 30), split=1) if a.comm.size > 1: From 1938f4a110b697c1e367c8540eeaf968f2e960b1 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Thu, 8 Jul 2021 09:43:03 +0200 Subject: [PATCH 06/18] abstracted section of setitem: key slice generation --- heat/core/dndarray.py | 51 ++++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 8a82d96564..4e0dfc5292 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -280,7 +280,7 @@ def lshape(self) -> Tuple[int]: @property def lshape_map(self) -> torch.Tensor: """ - Returns the lshape map. if it has been previously created then it will be created here + Returns the lshape map. If it has been previously created then it will be created here """ return self.create_lshape_map() @@ -576,7 +576,7 @@ def cpu(self) -> DNDarray: self.__device = devices.cpu return self - def create_lshape_map(self, recreate: Optional[bool] = True) -> torch.Tensor: + def create_lshape_map(self, recreate: bool = True) -> torch.Tensor: """ Generate a 'map' of the lshapes of the data on all processes. 
Units are ``(process rank, lshape)`` @@ -586,7 +586,6 @@ def create_lshape_map(self, recreate: Optional[bool] = True) -> torch.Tensor: recreate : bool, optional if False (default) and the lshape map has already been created, use the previous result. Otherwise, create the lshape_map - Default: False """ if not recreate and self.__lshape_map is not None: return self.__lshape_map @@ -1013,7 +1012,9 @@ def ravel(self): """ return manipulations.ravel(self) - def redistribute_(self, lshape_map: torch.Tensor = None, target_map: torch.Tensor = None): + def redistribute_( + self, lshape_map: Optional[torch.Tensor] = None, target_map: Optional[torch.Tensor] = None + ): """ Redistributes the data of the :class:`DNDarray` *along the split axis* to match the given target map. This function does not modify the non-split dimensions of the ``DNDarray``. @@ -1424,7 +1425,6 @@ def __setitem__( # this means that we need to know how much data is where for both DNDarrays # if the value data is not in the right place, then it will need to be moved - # if isinstance(key, tuple): if isinstance(key[self.split], slice): key = list(key) key_start = key[self.split].start if key[self.split].start is not None else 0 @@ -1463,16 +1463,9 @@ def __setitem__( else: key_start_l = 0 if r != actives[0] else key_start - chunk_starts[r] key_stop_l = ends[r] if r != actives[-1] else key_stop - chunk_starts[r] - if key_step is not None and r > actives[0]: - offset = (chunk_ends[r - 1] - og_key_start) % key_step - if key_step > 2 and offset > 0: - key_start_l += key_step - offset - elif key_step == 2 and offset > 0: - key_start_l += (chunk_ends[r - 1] - og_key_start) % key_step - if isinstance(key_start_l, torch.Tensor): - key_start_l = key_start_l.item() - if isinstance(key_stop_l, torch.Tensor): - key_stop_l = key_stop_l.item() + key_start_l, key_stop_l = self.__setitem_get_key_start_stop( + r, actives, key_start_l, key_stop_l, key_step, chunk_ends, og_key_start + ) loc_key = key.copy() loc_key[self.split] = slice(key_start_l, key_stop_l, key_step) @@ -1509,16 +1502,9 @@ def __setitem__( return # non-active ranks can exit here key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank] key_stop = ends[rank] if rank != actives[-1] else key_stop - chunk_starts[rank] - if key_step is not None and rank > actives[0]: - offset = (chunk_ends[rank - 1] - og_key_start) % key_step - if key_step > 2 and offset > 0: - key_start += key_step - offset - elif key_step == 2 and offset > 0: - key_start += (chunk_ends[rank - 1] - og_key_start) % key_step - if isinstance(key_start, torch.Tensor): - key_start = key_start.item() - if isinstance(key_stop, torch.Tensor): - key_stop = key_stop.item() + key_start, key_stop = self.__setitem_get_key_start_stop( + rank, actives, key_start, key_stop, key_step, chunk_ends, og_key_start + ) key[self.split] = slice(key_start, key_stop, key_step) # todo: need to slice the values to be the right size... 
@@ -1553,6 +1539,21 @@ def __setitem__( key[self.split] = key[self.split] + self.shape[self.split] - chunk_start self.__setter(tuple(key), value) + @staticmethod + def __setitem_get_key_start_stop(rank, actives, key_st, key_sp, step, ends, og_key_st): + start, stop = None, None + if step is not None and rank > actives[0]: + offset = (ends[rank - 1] - og_key_st) % step + if step > 2 and offset > 0: + key_st += step - offset + elif step == 2 and offset > 0: + key_st += (ends[rank - 1] - og_key_st) % step + if isinstance(key_st, torch.Tensor): + start = key_st.item() + if isinstance(key_sp, torch.Tensor): + stop = key_sp.item() + return start, stop + def __setter( self, key: Union[int, Tuple[int, ...], List[int, ...]], From 94af9f411569c3acd44b5fb248fffdc9680071ad Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Thu, 8 Jul 2021 10:01:16 +0200 Subject: [PATCH 07/18] used key logic in getitem, added typehints/simple docstring to xitem_get_key_start_stop --- heat/core/dndarray.py | 57 +++++++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 4e0dfc5292..595ebd8bb6 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -845,16 +845,9 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar if rank in actives: key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank] key_stop = counts[rank] if rank != actives[-1] else key_stop - chunk_starts[rank] - if key_step is not None and rank > actives[0]: - offset = (chunk_ends[rank - 1] - og_key_start) % key_step - if key_step > 2 and offset > 0: - key_start += key_step - offset - elif key_step == 2 and offset > 0: - key_start += (chunk_ends[rank - 1] - og_key_start) % key_step - if isinstance(key_start, torch.Tensor): - key_start = key_start.item() - if isinstance(key_stop, torch.Tensor): - key_stop = key_stop.item() + key_start, key_stop = self.__xitem_get_key_start_stop( + rank, actives, key_start, key_stop, key_step, chunk_ends, og_key_start + ) key[self.split] = slice(key_start, key_stop, key_step) lout[new_split] = ( math.ceil((key_stop - key_start) / key_step) @@ -1463,7 +1456,7 @@ def __setitem__( else: key_start_l = 0 if r != actives[0] else key_start - chunk_starts[r] key_stop_l = ends[r] if r != actives[-1] else key_stop - chunk_starts[r] - key_start_l, key_stop_l = self.__setitem_get_key_start_stop( + key_start_l, key_stop_l = self.__xitem_get_key_start_stop( r, actives, key_start_l, key_stop_l, key_step, chunk_ends, og_key_start ) loc_key = key.copy() @@ -1502,7 +1495,7 @@ def __setitem__( return # non-active ranks can exit here key_start = 0 if rank != actives[0] else key_start - chunk_starts[rank] key_stop = ends[rank] if rank != actives[-1] else key_stop - chunk_starts[rank] - key_start, key_stop = self.__setitem_get_key_start_stop( + key_start, key_stop = self.__xitem_get_key_start_stop( rank, actives, key_start, key_stop, key_step, chunk_ends, og_key_start ) key[self.split] = slice(key_start, key_stop, key_step) @@ -1539,21 +1532,6 @@ def __setitem__( key[self.split] = key[self.split] + self.shape[self.split] - chunk_start self.__setter(tuple(key), value) - @staticmethod - def __setitem_get_key_start_stop(rank, actives, key_st, key_sp, step, ends, og_key_st): - start, stop = None, None - if step is not None and rank > actives[0]: - offset = (ends[rank - 1] - og_key_st) % step - if step > 2 and offset > 0: - key_st += step - offset - elif step == 2 and offset > 0: - key_st += (ends[rank - 
1] - og_key_st) % step - if isinstance(key_st, torch.Tensor): - start = key_st.item() - if isinstance(key_sp, torch.Tensor): - stop = key_sp.item() - return start, stop - def __setter( self, key: Union[int, Tuple[int, ...], List[int, ...]], @@ -1617,6 +1595,31 @@ def tolist(self, keepsplit: bool = False) -> List: return self.__array.tolist() + @staticmethod + def __xitem_get_key_start_stop( + rank: int, + actives: list, + key_st: int, + key_sp: int, + step: int, + ends: torch.Tensor, + og_key_st: int, + ) -> Tuple[int, int]: + # this does some basic logic for adjusting the starting and stoping of the a key for + # setitem and getitem + start, stop = None, None + if step is not None and rank > actives[0]: + offset = (ends[rank - 1] - og_key_st) % step + if step > 2 and offset > 0: + key_st += step - offset + elif step == 2 and offset > 0: + key_st += (ends[rank - 1] - og_key_st) % step + if isinstance(key_st, torch.Tensor): + start = key_st.item() + if isinstance(key_sp, torch.Tensor): + stop = key_sp.item() + return start, stop + # HeAT imports at the end to break cyclic dependencies from . import complex_math From e4f5364d3ac0ed8e2d3a2d904150a51b6ed74e13 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Thu, 8 Jul 2021 10:33:22 +0200 Subject: [PATCH 08/18] corrected false logic in key start stop adjustments --- heat/core/dndarray.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 595ebd8bb6..a05ac58105 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1607,7 +1607,6 @@ def __xitem_get_key_start_stop( ) -> Tuple[int, int]: # this does some basic logic for adjusting the starting and stoping of the a key for # setitem and getitem - start, stop = None, None if step is not None and rank > actives[0]: offset = (ends[rank - 1] - og_key_st) % step if step > 2 and offset > 0: @@ -1615,10 +1614,10 @@ def __xitem_get_key_start_stop( elif step == 2 and offset > 0: key_st += (ends[rank - 1] - og_key_st) % step if isinstance(key_st, torch.Tensor): - start = key_st.item() + key_st = key_st.item() if isinstance(key_sp, torch.Tensor): - stop = key_sp.item() - return start, stop + key_sp = key_sp.item() + return key_st, key_sp # HeAT imports at the end to break cyclic dependencies From 2a82e058eda777b8e831447f2cc9a0ba72d5dcd0 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 13 Jul 2021 15:27:27 +0200 Subject: [PATCH 09/18] added a raise in setitem for when the value and self have different split axes --- heat/core/dndarray.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index a05ac58105..59ef9c7fec 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -576,18 +576,18 @@ def cpu(self) -> DNDarray: self.__device = devices.cpu return self - def create_lshape_map(self, recreate: bool = True) -> torch.Tensor: + def create_lshape_map(self, force_check: bool = True) -> torch.Tensor: """ Generate a 'map' of the lshapes of the data on all processes. Units are ``(process rank, lshape)`` Parameters ---------- - recreate : bool, optional + force_check : bool, optional if False (default) and the lshape map has already been created, use the previous result. 
Otherwise, create the lshape_map """ - if not recreate and self.__lshape_map is not None: + if not force_check and self.__lshape_map is not None: return self.__lshape_map lshape_map = torch.zeros( @@ -1367,6 +1367,16 @@ def __setitem__( [0., 1., 0., 0., 0.]]) """ key = getattr(key, "copy()", key) + try: + if value.split != self.split: + warnings.warn( + f"\nvalue.split {value.split} not equal to this DNDarray's split:" + f" {self.split}. this may cause errors or unwanted behavior", + category=RuntimeWarning, + ) + except AttributeError: + pass + if isinstance(key, DNDarray) and key.ndim == self.ndim: # this splits the key into torch.Tensors in each dimension for advanced indexing lkey = [slice(None, None, None)] * self.ndim From 30e2df495f47a4b10e863c07df62084f1b76a5aa Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 13 Jul 2021 15:47:56 +0200 Subject: [PATCH 10/18] added handling for single value DNDarrays in key for setitem --- heat/core/dndarray.py | 10 +++++++++- heat/core/tests/test_dndarray.py | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index ae668a8640..aa217e5527 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1402,6 +1402,13 @@ def __setitem__( kend = key[ell_ind + 1 :] slices = [slice(None)] * (self.ndim - (len(kst) + len(kend))) key = kst + slices + kend + + for c, k in enumerate(key): + try: + key[c] = k.item() + except AttributeError: + pass + key = tuple(key) if not self.is_distributed(): @@ -1421,6 +1428,8 @@ def __setitem__( chunk_start = chunk_slice[self.split].start chunk_end = chunk_slice[self.split].stop + self_proxy = torch.ones((1,)).as_strided(self.gshape, [0] * self.ndim) + if not isinstance(key, tuple): return self.__setter(key, value) # returns None @@ -1458,7 +1467,6 @@ def __setitem__( target_reshape_map = torch.zeros( (self.comm.size, self.ndim), dtype=torch.int, device=self.device.torch_device ) - self_proxy = torch.ones((1,)).as_strided(self.gshape, [0] * self.ndim) for r in range(self.comm.size): if r not in actives: loc_key = key.copy() diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 3f967bac53..6f10508bac 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1089,7 +1089,7 @@ def test_setitem_getitem(self): # slice in 1st dim across 1 node (2nd) w/ singular second dim c = ht.zeros((13, 5), split=0) - c[8:12, 1] = 1 + c[8:12, ht.array(1)] = 1 b = c[8:12, np.int64(1)] self.assertTrue((b == 1).all()) self.assertEqual(b.gshape, (4,)) From 755c786f81d8b4e5c56f30ecd2138cd8a062efa0 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 13 Jul 2021 16:41:13 +0200 Subject: [PATCH 11/18] corrected try/expect in setitem to work with torch tensors as well --- heat/core/dndarray.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index aa217e5527..f8b553c719 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1369,12 +1369,14 @@ def __setitem__( key = getattr(key, "copy()", key) try: if value.split != self.split: + val_split = int(value.split) + sp = self.split warnings.warn( - f"\nvalue.split {value.split} not equal to this DNDarray's split:" - f" {self.split}. this may cause errors or unwanted behavior", + f"\nvalue.split {val_split} not equal to this DNDarray's split:" + f" {sp}. 
this may cause errors or unwanted behavior", category=RuntimeWarning, ) - except AttributeError: + except (AttributeError, TypeError): pass if isinstance(key, DNDarray) and key.ndim == self.ndim: @@ -1406,7 +1408,7 @@ def __setitem__( for c, k in enumerate(key): try: key[c] = k.item() - except AttributeError: + except (AttributeError, ValueError): pass key = tuple(key) From 8d0833021e09137f5c62cdcaa7e195165fb5f38d Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 20 Jul 2021 09:40:15 +0200 Subject: [PATCH 12/18] removing dead code --- heat/core/dndarray.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index f8b553c719..1dd3e8eae1 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1432,9 +1432,6 @@ def __setitem__( self_proxy = torch.ones((1,)).as_strided(self.gshape, [0] * self.ndim) - if not isinstance(key, tuple): - return self.__setter(key, value) # returns None - # if the value is a DNDarray, the divisions need to be balanced: # this means that we need to know how much data is where for both DNDarrays # if the value data is not in the right place, then it will need to be moved @@ -1498,7 +1495,7 @@ def __setitem__( value_slice = [slice(None, None, None)] * value.ndim step2 = key_step if key_step is not None else 1 key_start = (chunk_starts_v[rank] - og_key_start).item() - # print(key_start) + if key_start < 0: key_start = 0 key_stop = key_start + key_stop From f193b0bb80596b75c4e8c93449b0fb5daea5500a Mon Sep 17 00:00:00 2001 From: Daniel Coquelin Date: Tue, 20 Jul 2021 09:42:03 +0200 Subject: [PATCH 13/18] Verb correction in lshape map creation Co-authored-by: mtar --- heat/core/dndarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 1dd3e8eae1..d61b882086 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -280,7 +280,7 @@ def lshape(self) -> Tuple[int]: @property def lshape_map(self) -> torch.Tensor: """ - Returns the lshape map. If it has been previously created then it will be created here + Returns the lshape map. If it hasn't been previously created then it will be created here. 
""" return self.create_lshape_map() From 34ba9c58a41271363667a2513fb544e223b0a1f7 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 20 Jul 2021 09:43:08 +0200 Subject: [PATCH 14/18] new changelog to add pending additions again --- CHANGELOG.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 53b0c82970..d81b4e92a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +# Pending additions + +## Bug Fixes +- [#831](https://github.com/helmholtz-analytics/heat/pull/831) `__getitem__` handling of `array-like` 1-element key + + # v1.1.0 ## Highlights @@ -50,7 +56,6 @@ Example on 2 processes: - [#820](https://github.com/helmholtz-analytics/heat/pull/820) `randn` values are pushed away from 0 by the minimum value the given dtype before being transformed into the Gaussian shape - [#821](https://github.com/helmholtz-analytics/heat/pull/821) Fixed `__getitem__` handling of distributed `DNDarray` key element - [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension -- [#831](https://github.com/helmholtz-analytics/heat/pull/831) `__getitem__` handling of `array-like` 1-element key ## Feature additions ### Exponential From 3ec7b448e1c94af40875b0ac0bfa638d89b49e41 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 20 Jul 2021 10:58:44 +0200 Subject: [PATCH 15/18] added tests for lshape map property and forced creation --- heat/core/tests/test_dndarray.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 6f10508bac..51a51c55aa 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -245,7 +245,8 @@ def test_astype(self): def test_balance_and_lshape_map(self): data = ht.zeros((70, 20), split=0) data = data[:50] - lshape_map = data.create_lshape_map() + data.lshape_map + lshape_map = data.create_lshape_map(force_check=False) # tests the property self.assertEqual(sum(lshape_map[..., 0]), 50) if sum(data.lshape) == 0: self.assertTrue(all(lshape_map[data.comm.rank] == 0)) From cf8aa1adc61c39025fcd271aea773972dcb5ccc5 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 20 Jul 2021 11:10:48 +0200 Subject: [PATCH 16/18] corrected incorrect changelog, wrong line was moved the the pending additions --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d81b4e92a3..51bd0a9c2e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,7 @@ # Pending additions ## Bug Fixes -- [#831](https://github.com/helmholtz-analytics/heat/pull/831) `__getitem__` handling of `array-like` 1-element key +- [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension # v1.1.0 @@ -55,7 +55,7 @@ Example on 2 processes: - [#811](https://github.com/helmholtz-analytics/heat/pull/811) Fixed memory leak in `DNDarray.larray` - [#820](https://github.com/helmholtz-analytics/heat/pull/820) `randn` values are pushed away from 0 by the minimum value the given dtype before being transformed into the Gaussian shape - [#821](https://github.com/helmholtz-analytics/heat/pull/821) Fixed `__getitem__` handling of distributed `DNDarray` key element -- [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values 
which have a different shape in the split dimension +- [#831](https://github.com/helmholtz-analytics/heat/pull/831) `__getitem__` handling of `array-like` 1-element key ## Feature additions ### Exponential From bf39b2569b41fac5d5e2aca93f99abfcc0742893 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 20 Jul 2021 13:25:35 +0200 Subject: [PATCH 17/18] added raise test for splits != case in setitem --- heat/core/tests/test_dndarray.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index 51a51c55aa..ac62846136 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1333,6 +1333,10 @@ def test_setitem_getitem(self): a[..., ...] with self.assertRaises(ValueError): a[..., ...] = 1 + with self.assertRaises(ValueError): + x = ht.ones((10, 10), split=0) + setting = ht.zeros((8, 8), split=1) + x[1:-1, 1:-1] = setting def test_size_gnumel(self): a = ht.zeros((10, 10, 10), split=None) From 89fb97715b0fe8bf82b680ca8f909041fc91cd89 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 20 Jul 2021 13:35:51 +0200 Subject: [PATCH 18/18] new raise test now only runs on multiple processes --- heat/core/tests/test_dndarray.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index ac62846136..92e1182ec2 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1333,10 +1333,11 @@ def test_setitem_getitem(self): a[..., ...] with self.assertRaises(ValueError): a[..., ...] = 1 - with self.assertRaises(ValueError): - x = ht.ones((10, 10), split=0) - setting = ht.zeros((8, 8), split=1) - x[1:-1, 1:-1] = setting + if a.comm.size > 1: + with self.assertRaises(ValueError): + x = ht.ones((10, 10), split=0) + setting = ht.zeros((8, 8), split=1) + x[1:-1, 1:-1] = setting def test_size_gnumel(self): a = ht.zeros((10, 10, 10), split=None)
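
A minimal usage sketch of the features introduced in this series: the `lshape_map` property, `create_lshape_map(force_check=...)`, the out-of-place `ht.redistribute`, and `__setitem__` accepting a value DNDarray whose shape differs along the split dimension. The calls mirror the tests added in heat/core/tests/test_dndarray.py; the exact local shapes depend on the number of MPI processes, and the variable names are illustrative only, not part of the patches.

    >>> import heat as ht
    >>> a = ht.ones((102, 102), split=0)
    >>> lmap = a.lshape_map                            # new property: (comm.size, ndim) tensor of local shapes
    >>> lmap = a.create_lshape_map(force_check=False)  # reuses the cached map when one already exists
    >>> setting = ht.zeros((100, 100), split=0)        # smaller than a along the split dimension
    >>> a[1:-1, 1:-1] = setting                        # value is redistributed internally to match the local slices
    >>> res = ht.all(a[1:-1, 1:-1] == 0)               # True, as in the tests for bug #825
    >>> b = ht.redistribute(a, target_map=lmap)        # new out-of-place variant of DNDarray.redistribute_() (no-op here)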