diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 07b4e48418..a220538cc8 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -384,7 +384,7 @@ def __prephalo(self, start, end) -> torch.Tensor: return self.__array[ix].clone().contiguous() - def get_halo(self, halo_size: int) -> torch.Tensor: + def get_halo(self, halo_size: int, prev: bool = True, next: bool = True) -> torch.Tensor: """ Fetch halos of size ``halo_size`` from neighboring ranks and save them in ``self.halo_next/self.halo_prev``. @@ -392,6 +392,10 @@ def get_halo(self, halo_size: int) -> torch.Tensor: ---------- halo_size : int Size of the halo. + prev : bool, optional + If True, fetch the halo from the previous rank. Default: True. + next : bool, optional + If True, fetch the halo from the next rank. Default: True. """ if not isinstance(halo_size, int): raise TypeError( @@ -433,25 +437,29 @@ def get_halo(self, halo_size: int) -> torch.Tensor: req_list = [] # exchange data with next populated process - if rank != last_rank: - self.comm.Isend(a_next, next_rank) - res_prev = torch.zeros( - a_prev.size(), dtype=a_prev.dtype, device=self.device.torch_device - ) - req_list.append(self.comm.Irecv(res_prev, source=next_rank)) + if prev: + if rank != last_rank: + self.comm.Isend(a_next, next_rank) + if rank != first_rank: + res_prev = torch.zeros( + a_prev.size(), dtype=a_prev.dtype, device=self.device.torch_device + ) + req_list.append(self.comm.Irecv(res_prev, source=prev_rank)) - if rank != first_rank: - self.comm.Isend(a_prev, prev_rank) - res_next = torch.zeros( - a_next.size(), dtype=a_next.dtype, device=self.device.torch_device - ) - req_list.append(self.comm.Irecv(res_next, source=prev_rank)) + if next: + if rank != first_rank: + self.comm.Isend(a_prev, prev_rank) + if rank != last_rank: + res_next = torch.zeros( + a_next.size(), dtype=a_next.dtype, device=self.device.torch_device + ) + req_list.append(self.comm.Irecv(res_next, source=next_rank)) for req in req_list: req.Wait() - self.__halo_next = res_prev - self.__halo_prev = res_next + self.__halo_next = res_next + self.__halo_prev = res_prev self.__ishalo = True def __cat_halo(self) -> torch.Tensor: