diff --git a/CHANGELOG.md b/CHANGELOG.md
index f373a718d5..004b462d8e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -224,6 +224,7 @@ Example on 2 processes:
 - [#664](https://github.com/helmholtz-analytics/heat/pull/664) New feature / enhancement: distributed `random.random_sample`, `random.random`, `random.sample`, `random.ranf`, `random.random_integer`
 - [#666](https://github.com/helmholtz-analytics/heat/pull/666) New feature: distributed prepend/append for `diff()`.
 - [#667](https://github.com/helmholtz-analytics/heat/pull/667) Enhancement `reshape`: rename axis parameter
+- [#678](https://github.com/helmholtz-analytics/heat/pull/678) New feature: distributed `tile`
 - [#674](https://github.com/helmholtz-analytics/heat/pull/674) New feature: `repeat`
 - [#670](https://github.com/helmholtz-analytics/heat/pull/670) New Feature: distributed `bincount()`
 - [#672](https://github.com/helmholtz-analytics/heat/pull/672) Bug / Enhancement: Remove `MPIRequest.wait()`, rewrite calls with capital letters. lower case `wait()` now falls back to the `mpi4py` function
diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py
index 2ba17b3c9a..e7242c360d 100644
--- a/heat/core/manipulations.py
+++ b/heat/core/manipulations.py
@@ -52,6 +52,7 @@
     "squeeze",
     "stack",
     "swapaxes",
+    "tile",
     "topk",
     "unique",
     "vsplit",
@@ -3596,6 +3597,262 @@ def vstack(arrays: Sequence[DNDarray, ...]) -> DNDarray:
     return concatenate(arrays, axis=0)
 
 
+def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray:
+    """
+    Construct a new DNDarray by repeating 'x' the number of times given by 'reps'.
+
+    If 'reps' has length 'd', the result will have 'max(d, x.ndim)' dimensions:
+
+    - if 'x.ndim < d', 'x' is promoted to be d-dimensional by prepending new axes.
+      So a shape (3,) array is promoted to (1, 3) for 2-D replication, or to (1, 1, 3)
+      for 3-D replication (if this is not the desired behavior, promote 'x' to d
+      dimensions manually before calling this function);
+
+    - if 'x.ndim > d', 'reps' will replicate the last 'd' dimensions of 'x', i.e., if
+      'x.shape' is (2, 3, 4, 5), a 'reps' of (2, 2) will be expanded to (1, 1, 2, 2).
+
+    Parameters
+    ----------
+    x : DNDarray
+        Input
+    reps : Sequence[int, ...]
+        Repetitions
+
+    Returns
+    -------
+    tiled : DNDarray
+        The tiled output. Split semantics: if `x` is distributed, the tiled data will be
+        distributed along the same dimension. Note that nominally `tiled.split != x.split`
+        in the case where `len(reps) > x.ndim`. See the example below.
+
+    Examples
+    --------
+    >>> x = ht.arange(12).reshape((4, 3)).resplit_(0)
+    >>> x
+    DNDarray([[ 0,  1,  2],
+              [ 3,  4,  5],
+              [ 6,  7,  8],
+              [ 9, 10, 11]], dtype=ht.int32, device=cpu:0, split=0)
+    >>> reps = (1, 2, 2)
+    >>> tiled = ht.tile(x, reps)
+    >>> tiled
+    DNDarray([[[ 0,  1,  2,  0,  1,  2],
+               [ 3,  4,  5,  3,  4,  5],
+               [ 6,  7,  8,  6,  7,  8],
+               [ 9, 10, 11,  9, 10, 11],
+               [ 0,  1,  2,  0,  1,  2],
+               [ 3,  4,  5,  3,  4,  5],
+               [ 6,  7,  8,  6,  7,  8],
+               [ 9, 10, 11,  9, 10, 11]]], dtype=ht.int32, device=cpu:0, split=1)
+    """
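The promotion rules above mirror `numpy.tile`. As a quick standalone cross-check of the documented shape semantics (illustrative shapes, plain NumPy, not part of the patch itself):

```python
import numpy as np

# 'x.ndim < d': a (4, 3) array is promoted to (1, 4, 3) before tiling
x = np.arange(12).reshape(4, 3)
assert np.tile(x, (1, 2, 2)).shape == (1, 8, 6)

# 'x.ndim > d': reps (2, 2) is expanded to (1, 1, 2, 2) for a 4-D input
y = np.ones((2, 3, 4, 5))
assert np.tile(y, (2, 2)).shape == (2, 3, 8, 10)
```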
+    # x can be a DNDarray or a scalar
+    try:
+        _ = x.larray
+    except AttributeError:
+        try:
+            _ = x.shape
+            raise TypeError("Input must be a DNDarray or a scalar, got {}".format(type(x)))
+        except AttributeError:
+            x = factories.array(x).reshape(1)
+
+    x_proxy = torch.ones((1,)).as_strided(x.gshape, [0] * x.ndim)
+
+    # torch-proof args/kwargs:
+    # torch `reps`: int or sequence of ints; numpy `reps`: can be array-like
+    try:
+        _ = x_proxy.repeat(reps)
+    except TypeError:
+        # `reps` is array-like or contains non-int elements
+        try:
+            reps = resplit(reps, None).tolist()
+        except AttributeError:
+            try:
+                reps = reps.tolist()
+            except AttributeError:
+                try:
+                    _ = x_proxy.repeat(reps)
+                except TypeError:
+                    raise TypeError(
+                        "reps must be a sequence of ints, got {}".format(
+                            list(type(i) for i in reps)
+                        )
+                    )
+                except RuntimeError:
+                    pass
+    except RuntimeError:
+        pass
+
+    try:
+        reps = list(reps)
+    except TypeError:
+        # scalar to list
+        reps = [reps]
+
+    # torch reps vs. numpy reps: dimensions
+    if len(reps) != x.ndim:
+        added_dims = abs(len(reps) - x.ndim)
+        if len(reps) > x.ndim:
+            new_shape = added_dims * (1,) + x.gshape
+            new_split = None if x.split is None else x.split + added_dims
+            x = x.reshape(new_shape, new_split=new_split)
+        else:
+            reps = added_dims * [1] + reps
+
+    out_gshape = tuple(x_proxy.repeat(reps).shape)
+
+    if not x.is_distributed() or reps[x.split] == 1:
+        # no repeats along the split axis: local operation
+        t_tiled = x.larray.repeat(reps)
+        return DNDarray(
+            t_tiled,
+            out_gshape,
+            dtype=x.dtype,
+            split=x.split,
+            device=x.device,
+            comm=x.comm,
+            balanced=x.balanced,
+        )
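For context on the `x_proxy` idiom used in this function: a one-element tensor viewed with zero strides stands in for the global (unsplit) array, so torch's `repeat` yields the global output shape even though no single process holds `x` in full. A minimal standalone sketch of just that trick:

```python
import torch

# one real element, viewed as if it had the global shape (4, 3)
proxy = torch.ones((1,)).as_strided((4, 3), (0, 0))
print(proxy.shape)                  # torch.Size([4, 3])
print(proxy.stride())               # (0, 0): every index reads the same element
# repeat promotes dimensions exactly like numpy.tile's reps expansion
print(proxy.repeat(1, 2, 2).shape)  # torch.Size([1, 8, 6])
```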
+    # repeats along the split axis, work along dim 0
+    size = x.comm.Get_size()
+    rank = x.comm.Get_rank()
+    trans_axes = list(range(x.ndim))
+    if x.split != 0:
+        trans_axes[0], trans_axes[x.split] = x.split, 0
+        reps[0], reps[x.split] = reps[x.split], reps[0]
+        x = linalg.transpose(x, trans_axes)
+        x_proxy = torch.ones((1,)).as_strided(x.gshape, [0] * x.ndim)
+        out_gshape = tuple(x_proxy.repeat(reps).shape)
+
+    local_x = x.larray
+
+    # allocate tiled DNDarray, at first tiled along the split axis only
+    split_reps = [rep if i == x.split else 1 for i, rep in enumerate(reps)]
+    split_tiled_shape = tuple(x_proxy.repeat(split_reps).shape)
+    tiled = factories.empty(split_tiled_shape, dtype=x.dtype, split=x.split, comm=x.comm)
+
+    # collect slicing information from all processes
+    slices_map = []
+    for array in [x, tiled]:
+        counts, displs = array.counts_displs()
+        t_slices_starts = torch.tensor(displs, device=local_x.device)
+        t_slices_ends = t_slices_starts + torch.tensor(counts, device=local_x.device)
+        slices_map.append([t_slices_starts, t_slices_ends])
+
+    t_slices_x, t_slices_tiled = slices_map
+
+    # keep track of repetitions:
+    # local_x_starts.shape, local_x_ends.shape changing from (size,) to (reps[split], size)
+    reps_indices = list(x.gshape[x.split] * rep for rep in range(reps[x.split]))
+    t_reps_indices = torch.tensor(reps_indices, dtype=torch.int32, device=local_x.device).reshape(
+        len(reps_indices), 1
+    )
+    for i, t in enumerate(t_slices_x):
+        t = t.repeat((reps[x.split], 1))
+        t += t_reps_indices
+        t_slices_x[i] = t
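The distribution logic that follows is interval arithmetic on these start/end tables: a source slice must be sent to a target chunk exactly when their half-open index intervals overlap, i.e. when `max(starts) < min(ends)`. A small sketch with made-up chunk boundaries (variable names are illustrative, not taken from the patch):

```python
import torch

src_starts = torch.tensor([0, 4, 8])   # chunks [0, 4), [4, 8), [8, 12)
src_ends = torch.tensor([4, 8, 12])    # owned by ranks 0, 1, 2
# one target chunk [6, 10):
overlap_starts = torch.max(src_starts, torch.tensor(6))
overlap_ends = torch.min(src_ends, torch.tensor(10))
senders = torch.where(overlap_ends - overlap_starts > 0)[0]
print(senders)  # tensor([1, 2]): only ranks 1 and 2 contribute to [6, 10)
```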
+    # distribution logic on current rank:
+    distr_map = []
+    slices_map = []
+    for i in range(2):
+        if i == 0:
+            # send logic for x slices on rank
+            local_x_starts = t_slices_x[0][:, rank].reshape(reps[x.split], 1)
+            local_x_ends = t_slices_x[1][:, rank].reshape(reps[x.split], 1)
+            t_tiled_starts, t_tiled_ends = t_slices_tiled
+        else:
+            # recv logic for tiled slices on rank
+            local_x_starts, local_x_ends = t_slices_x
+            t_tiled_starts = t_slices_tiled[0][rank]
+            t_tiled_ends = t_slices_tiled[1][rank]
+        t_max_starts = torch.max(local_x_starts, t_tiled_starts)
+        t_min_ends = torch.min(local_x_ends, t_tiled_ends)
+        coords = torch.where(t_min_ends - t_max_starts > 0)
+        # remove repeat offset from slices if sending
+        if i == 0:
+            t_max_starts -= t_reps_indices
+            t_min_ends -= t_reps_indices
+        starts = t_max_starts[coords].unsqueeze_(0)
+        ends = t_min_ends[coords].unsqueeze_(0)
+        slices_map.append(torch.cat((starts, ends), dim=0))
+        distr_map.append(coords)
+
+    # bookkeeping in preparation for Alltoallv
+    send_map, recv_map = distr_map
+    send_rep, send_to_ranks = send_map
+    recv_rep, recv_from_ranks = recv_map
+    send_slices, recv_slices = slices_map
+
+    # do not assume that `x` is balanced
+    _, displs = x.counts_displs()
+    offset_x = displs[rank]
+    # impose load balance on the output
+    offset_tiled, _, _ = tiled.comm.chunk(tiled.gshape, tiled.split)
+    t_tiled = tiled.larray
+
+    active_send_counts = send_slices.clone()
+    active_send_counts[0] *= -1
+    active_send_counts = active_send_counts.sum(0)
+    active_recv_counts = recv_slices.clone()
+    active_recv_counts[0] *= -1
+    active_recv_counts = active_recv_counts.sum(0)
+    send_slices -= offset_x
+    recv_slices -= offset_tiled
+    recv_buf = t_tiled.clone()
+    # we need as many Alltoallv calls as there are repeats along the split axis
+    for rep in range(reps[x.split]):
+        # send_data, send_counts, send_displs on rank
+        all_send_counts = [0] * size
+        all_send_displs = [0] * size
+        send_this_rep = torch.where(send_rep == rep)[0].tolist()
+        dest_this_rep = send_to_ranks[send_this_rep].tolist()
+        for i, j in zip(send_this_rep, dest_this_rep):
+            all_send_counts[j] = active_send_counts[i].item()
+            all_send_displs[j] = send_slices[0][i].item()
+        local_send_slice = [slice(None)] * x.ndim
+        local_send_slice[x.split] = slice(
+            all_send_displs[0], all_send_displs[0] + sum(all_send_counts)
+        )
+        send_buf = local_x[local_send_slice].clone()
+
+        # recv_data, recv_counts, recv_displs on rank
+        all_recv_counts = [0] * size
+        all_recv_displs = [0] * size
+        recv_this_rep = torch.where(recv_rep == rep)[0].tolist()
+        orig_this_rep = recv_from_ranks[recv_this_rep].tolist()
+        for i, j in zip(recv_this_rep, orig_this_rep):
+            all_recv_counts[j] = active_recv_counts[i].item()
+            all_recv_displs[j] = recv_slices[0][i].item()
+        local_recv_slice = [slice(None)] * x.ndim
+        local_recv_slice[x.split] = slice(
+            all_recv_displs[0], all_recv_displs[0] + sum(all_recv_counts)
+        )
+        x.comm.Alltoallv(
+            (send_buf, all_send_counts, all_send_displs),
+            (recv_buf, all_recv_counts, all_recv_displs),
+        )
+        t_tiled[local_recv_slice] = recv_buf[local_recv_slice]
+
+    # finally, tile along the non-split axes if needed
+    reps[x.split] = 1
+    tiled = DNDarray(
+        t_tiled.repeat(reps),
+        out_gshape,
+        dtype=x.dtype,
+        split=x.split,
+        device=x.device,
+        comm=x.comm,
+        balanced=True,
+    )
+    if trans_axes != list(range(x.ndim)):
+        # transpose back to the original axis order
+        x = linalg.transpose(x, trans_axes)
+        tiled = linalg.transpose(tiled, trans_axes)
+
+    return tiled
+
+
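The exchange loop above is built on `Alltoallv`, where each rank supplies one send count/displacement and one receive count/displacement per peer, all indexing into flat buffers. A minimal standalone `mpi4py` sketch of that contract, assuming exactly 2 processes and illustrative buffer sizes:

```python
from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
rank = comm.Get_rank()  # run with: mpirun -np 2 python sketch.py

send_buf = np.arange(4, dtype=np.float64) + 10 * rank
# rank 0 sends 1 element to rank 0 and 3 to rank 1; rank 1 sends 2 and 2
send_counts = (1, 3) if rank == 0 else (2, 2)
send_displs = (0, 1)
# the receive side mirrors this: rank 0 expects (1, 2), rank 1 expects (3, 2)
recv_counts = (1, 2) if rank == 0 else (3, 2)
recv_displs = (0, recv_counts[0])
recv_buf = np.empty(sum(recv_counts), dtype=np.float64)

comm.Alltoallv((send_buf, send_counts, send_displs),
               (recv_buf, recv_counts, recv_displs))
```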
 def topk(
     a: DNDarray,
     k: int,
diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py
index cea956269c..c8c2a1516e 100644
--- a/heat/core/tests/test_manipulations.py
+++ b/heat/core/tests/test_manipulations.py
@@ -3267,6 +3267,83 @@ def test_swapaxes(self):
         with self.assertRaises(TypeError):
             ht.swapaxes(x, 4.9, "abc")
 
+    def test_tile(self):
+        # test local tile, tuple reps
+        x = ht.arange(12).reshape((4, 3))
+        reps = (2, 1)
+        ht_tiled = ht.tile(x, reps)
+        np_tiled = np.tile(x.numpy(), reps)
+        self.assertTrue((np_tiled == ht_tiled.numpy()).all())
+        self.assertTrue(ht_tiled.dtype is x.dtype)
+
+        # test scalar x
+        x = ht.array(9.0)
+        reps = (2, 1)
+        ht_tiled = ht.tile(x, reps)
+        np_tiled = np.tile(x.numpy(), reps)
+        self.assertTrue((np_tiled == ht_tiled.numpy()).all())
+        self.assertTrue(ht_tiled.dtype is x.dtype)
+
+        # test distributed tile along the split axis
+        # len(reps) > x.ndim
+        split = 1
+        x = ht.random.randn(4, 3, split=split)
+        reps = ht.random.randint(2, 10, size=(4,))
+        tiled_along_split = ht.tile(x, reps)
+        np_tiled_along_split = np.tile(x.numpy(), reps.tolist())
+        self.assertTrue((tiled_along_split.numpy() == np_tiled_along_split).all())
+        self.assertTrue(tiled_along_split.dtype is x.dtype)
+
+        # test distributed tile along the zero split axis, array-like reps
+        # len(reps) > x.ndim
+        split = 0
+        x = ht.random.randn(4, 3, split=split)
+        reps = np.random.randint(2, 10, size=(4,))
+        tiled_along_split = ht.tile(x, reps)
+        np_tiled_along_split = np.tile(x.numpy(), reps)
+        self.assertTrue((tiled_along_split.numpy() == np_tiled_along_split).all())
+        self.assertTrue(tiled_along_split.dtype is x.dtype)
+
+        # test distributed tile() on an imbalanced DNDarray
+        x = ht.random.randn(100, split=0)
+        x = x[ht.where(x > 0)]
+        reps = 5
+        imbalanced_tiled_along_split = ht.tile(x, reps)
+        np_imbalanced_tiled_along_split = np.tile(x.numpy(), reps)
+        self.assertTrue(
+            (imbalanced_tiled_along_split.numpy() == np_imbalanced_tiled_along_split).all()
+        )
+        self.assertTrue(imbalanced_tiled_along_split.dtype is x.dtype)
+        self.assertTrue(imbalanced_tiled_along_split.is_balanced(force_check=True))
+
+        # test tile along a non-split axis
+        # len(reps) < x.ndim
+        split = 1
+        x = ht.random.randn(4, 5, 3, 10, dtype=ht.float64, split=split)
+        reps = (2, 2)
+        tiled_along_non_split = ht.tile(x, reps)
+        np_tiled_along_non_split = np.tile(x.numpy(), reps)
+        self.assertTrue((tiled_along_non_split.numpy() == np_tiled_along_non_split).all())
+        self.assertTrue(tiled_along_non_split.dtype is x.dtype)
+
+        # test tile along the split axis
+        # len(reps) == x.ndim
+        split = 1
+        x = ht.random.randn(3, 3, dtype=ht.float64, split=split)
+        reps = (2, 3)
+        tiled_along_split = ht.tile(x, reps)
+        np_tiled_along_split = np.tile(x.numpy(), reps)
+        self.assertTrue((tiled_along_split.numpy() == np_tiled_along_split).all())
+        self.assertTrue(tiled_along_split.dtype is x.dtype)
+
+        # test exceptions
+        float_reps = (1, 2, 2, 1.5)
+        with self.assertRaises(TypeError):
+            tiled_along_split = ht.tile(x, float_reps)
+        arraylike_float_reps = torch.tensor(float_reps)
+        with self.assertRaises(TypeError):
+            tiled_along_split = ht.tile(x, arraylike_float_reps)
+
     def test_topk(self):
         size = ht.MPI_WORLD.size
         if size == 1: