From 06e5d8c9f638fb5729ebd37495a27be48c76a0ec Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 8 Sep 2020 03:40:13 +0200 Subject: [PATCH 01/32] Implement tile() for local repeats --- heat/core/manipulations.py | 61 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index db8fb7a321..c86c2c9ca7 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -33,6 +33,7 @@ "sort", "squeeze", "stack", + "tile", "topk", "unique", "vstack", @@ -2155,6 +2156,66 @@ def vstack(tup): return concatenate(tup, axis=0) +def tile(x, reps): + """ + Construct an array by repeating A the number of times given by reps. + + If reps has length d, the result will have dimension of max(d, A.ndim). + + If A.ndim < d, A is promoted to be d-dimensional by prepending new axes. So a shape (3,) array is promoted to (1, 3) for 2-D replication, or shape (1, 1, 3) for 3-D replication. If this is not the desired behavior, promote A to d-dimensions manually before calling this function. + + If A.ndim > d, reps is promoted to A.ndim by pre-pending 1’s to it. Thus for an A of shape (2, 3, 4, 5), a reps of (2, 2) is treated as (1, 1, 2, 2). + """ + # input sanitation + # x is DNDarray + # x.dim >= 1 + + # calculate map of new gshape, lshape + + if len(reps) > x.ndim: + added_dims = len(reps) - x.ndim + new_shape = added_dims * (1,) + x.gshape + new_split = None if x.split is None else x.split + added_dims + x = x.reshape(new_shape, axis=new_split) + split = x.split + if split is None or reps[split] == 1: + # no repeats along the split axis: local operation + t_x = x._DNDarray__array + t_tiled = t_x.repeat(reps) + return factories.array(t_tiled, dtype=x.dtype, split=split, comm=x.comm) + else: + raise NotImplementedError("ht.tile() not implemented yet for repeats along the split axis") + # size = x.comm.Get_size() + # rank = x.comm.Get_rank() + # # repeats along the split axis: communication needed + # output_shape = tuple(s * r for s, r in zip(x.gshape, reps)) + # tiled = factories.empty(output_shape, dtype=x.dtype, split=split, comm=x.comm) + # current_offset, current_lshape, current_slice = x.comm.chunk(x.gshape, split) + # tiled_offset, tiled_lshape, tiled_slice = tiled.comm.chunk(tiled.gshape, split) + # t_current_map = x.create_lshape_map() + # t_tiled_map = tiled.create_lshape_map() + # # map offsets (torch tensor with shape (size, 2) ) + # t_offset_map = torch.stack( + # ( + # t_current_map[:, split].cumsum(0) - t_current_map[:, split], + # t_tiled_map[:, split].cumsum(0) - t_tiled_map[:, split], + # t_tiled_map[rank, split] - t_current_map[:, split] + 1, + # ), + # dim=1, + # ) + + # # col 0 = current offsets, col 1 = tiled offsets + # recv_rank = torch.where( + # 0 + # <= t_offset_map[:, 0] - t_offset_map[:, 1] + # <= t_tiled_map[:, split] - t_current_map[:, split] + 1 + # ) + + # # use distributed setitem! + # # then torch.repeat on non-distributed dimensions + # pass + + def topk(a, k, dim=None, largest=True, sorted=True, out=None): """ Returns the k highest entries in the array. 
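A minimal usage sketch of the local code path added in this patch, for orientation; it is not part of the committed diff. It assumes only public heat API that already exists at this point in the series (ht.arange, DNDarray.reshape) plus the new ht.tile, and mirrors numpy.tile semantics for a non-split array:

    import heat as ht

    # 2 x 3 array, not distributed (split=None), so tiling stays process-local
    x = ht.arange(6).reshape((2, 3))
    # repeat twice along axis 0 and three times along axis 1
    tiled = ht.tile(x, (2, 3))
    # the global shape grows elementwise: (2*2, 3*3) == (4, 9)
    print(tiled.shape, tiled.split)

For a distributed array with repeats along the split axis, this first version still raises NotImplementedError; that case is implemented later in the series.
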
From 29521435f6dcc0a6eee142ecba8f7dba5d8dd953 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 8 Sep 2020 04:35:25 +0200 Subject: [PATCH 02/32] Fix output split in non_distributed tiling --- heat/core/manipulations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index c86c2c9ca7..c84ffcdded 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2182,7 +2182,7 @@ def tile(x, reps): # no repeats along the split axis: local operation t_x = x._DNDarray__array t_tiled = t_x.repeat(reps) - return factories.array(t_tiled, dtype=x.dtype, split=split, comm=x.comm) + return factories.array(t_tiled, dtype=x.dtype, is_split=split, comm=x.comm) else: raise NotImplementedError("ht.tile() not implemented yet for repeats along the split axis") # size = x.comm.Get_size() From f10af0cc738c15139a39eb9a68be0809998682b7 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 11 Sep 2020 11:39:54 +0200 Subject: [PATCH 03/32] Implement distributed tile() --- heat/core/manipulations.py | 129 ++++++++++++++++++++++++++++--------- 1 file changed, 98 insertions(+), 31 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index c84ffcdded..8fb801e060 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2178,42 +2178,109 @@ def tile(x, reps): new_split = None if x.split is None else x.split + added_dims x = x.reshape(new_shape, axis=new_split) split = x.split + t_x = x._DNDarray__array if split is None or reps[split] == 1: # no repeats along the split axis: local operation - t_x = x._DNDarray__array t_tiled = t_x.repeat(reps) return factories.array(t_tiled, dtype=x.dtype, is_split=split, comm=x.comm) else: - raise NotImplementedError("ht.tile() not implemented yet for repeats along the split axis") - # size = x.comm.Get_size() - # rank = x.comm.Get_rank() - # # repeats along the split axis: communication needed - # output_shape = tuple(s * r for s, r in zip(x.gshape, reps)) - # tiled = factories.empty(output_shape, dtype=x.dtype, split=split, comm=x.comm) - # current_offset, current_lshape, current_slice = x.comm.chunk(x.gshape, split) - # tiled_offset, tiled_lshape, tiled_slice = tiled.comm.chunk(tiled.gshape, split) - # t_current_map = x.create_lshape_map() - # t_tiled_map = tiled.create_lshape_map() - # # map offsets (torch tensor with shape (size, 2) ) - # t_offset_map = torch.stack( - # ( - # t_current_map[:, split].cumsum(0) - t_current_map[:, split], - # t_tiled_map[:, split].cumsum(0) - t_tiled_map[:, split], - # t_tiled_map[rank, split] - t_current_map[:, split] + 1, - # ), - # dim=1, - # ) - - # # col 0 = current offsets, col 1 = tiled offsets - # recv_rank = torch.where( - # 0 - # <= t_offset_map[:, 0] - t_offset_map[:, 1] - # <= t_tiled_map[:, split] - t_current_map[:, split] + 1 - # ) - - # # use distributed setitem! - # # then torch.repeat on non-distributed dimensions - # pass + # repeats along the split axis: communication needed + size = x.comm.Get_size() + rank = x.comm.Get_rank() + x_shape = x.gshape + # allocate tiled DNDarray + tiled_shape = tuple(s * r for s, r in zip(x_shape, reps)) + tiled = factories.empty(tiled_shape, dtype=x.dtype, split=split, comm=x.comm) + # collect slicing information from all processes. + # t_(...) 
indicates process-local torch tensors + lshape_maps = [] + slices_map = [] + t_0 = torch.tensor([0], dtype=torch.int32) + for array in [x, tiled]: + t_lshape_map = array.create_lshape_map() + lshape_maps.append(t_lshape_map) + t_slices = torch.cat((t_0, t_lshape_map[:, split])).cumsum(0) + t_slices_starts = t_slices[:size] + t_slices_ends = t_slices[1:] + slices_map.append([t_slices_starts, t_slices_ends]) + + t_slices_x, t_slices_tiled = slices_map + + # keep track of repetitions + # t_x_starts.shape, t_x_ends.shape changing from (size,) to (reps[split], size) + reps_indices = list(x_shape[split] * rep for rep in (range(reps[split]))) + t_reps_indices = torch.tensor(reps_indices, dtype=torch.int32).reshape(len(reps_indices), 1) + for i, t in enumerate(t_slices_x): + t = t.repeat((reps[split], 1)) + t += t_reps_indices + t_slices_x[i] = t + + # distribution logic on current rank: + distr_map = [] + slices_map = [] + for i in range(2): + if i == 0: + # send logic for x slices on rank + t_x_starts = t_slices_x[0][:, rank].reshape(reps[split], 1) + t_x_ends = t_slices_x[1][:, rank].reshape(reps[split], 1) + t_tiled_starts, t_tiled_ends = t_slices_tiled + else: + # recv logic for tiled slices on rank + t_x_starts, t_x_ends = t_slices_x + t_tiled_starts = t_slices_tiled[0][rank] + t_tiled_ends = t_slices_tiled[1][rank] + t_max_starts = torch.max(t_x_starts, t_tiled_starts) + t_min_ends = torch.min(t_x_ends, t_tiled_ends) + coords = torch.where(t_min_ends - t_max_starts > 0) + # remove repeat offset from slices if sending + if i == 0: + t_max_starts -= t_reps_indices + t_min_ends -= t_reps_indices + starts = t_max_starts[coords] + ends = t_min_ends[coords] + slices_map.append([starts, ends]) + distr_map.append(coords) + + send_map, recv_map = distr_map + send_slices, recv_slices = slices_map + send_to_ranks = send_map[1].tolist() + recv_from_ranks = recv_map[1].tolist() + + # allocate local buffer for incoming data + t_local_tile = torch.zeros( + tuple(lshape_maps[1][0].tolist()), dtype=x._DNDarray__array.dtype + ) + t_tiled = tiled._DNDarray__array + + send_slice = recv_slice = local_slice = tiled.ndim * (slice(None, None, None),) + offset_x, _, _ = x.comm.chunk(x.gshape, x.split) + offset_tiled, _, _ = tiled.comm.chunk(tiled.gshape, tiled.split) + + for i, r in enumerate(send_to_ranks): + start = send_slices[0][i] - offset_x + stop = send_slices[1][i] - offset_x + send_slice = send_slice[:split] + (slice(start, stop, None),) + send_slice[split + 1 :] + local_slice = ( + local_slice[:split] + (slice(0, stop - start, None),) + local_slice[split + 1 :] + ) + t_local_tile[local_slice] = t_x[send_slice] + x.comm.Send(t_local_tile, r) + for i, r in enumerate(recv_from_ranks): + start = recv_slices[0][i] - offset_tiled + stop = recv_slices[1][i] - offset_tiled + recv_slice = recv_slice[:split] + (slice(start, stop, None),) + recv_slice[split + 1 :] + local_slice = ( + local_slice[:split] + (slice(0, stop - start, None),) + local_slice[split + 1 :] + ) + x.comm.Recv(t_local_tile, r) + tiled._DNDarray__array[recv_slice] = t_local_tile[local_slice] + + # finally tile along non-split axes if needed. 
Needs change in slices definition above + # reps = list(reps) + # reps[split] = 1 + # tiled._DNDarray__array = tiled._DNDarray__array.repeat(reps) + + return tiled def topk(a, k, dim=None, largest=True, sorted=True, out=None): From ee7eb5867e21bf1aaaaf53fc28332d12d6e5abdc Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 8 Sep 2020 03:40:13 +0200 Subject: [PATCH 04/32] Implement tile() for local repeats --- heat/core/manipulations.py | 61 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 32f382669b..9a834b50fc 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -33,6 +33,7 @@ "sort", "squeeze", "stack", + "tile", "topk", "unique", "vstack", @@ -2155,6 +2156,66 @@ def vstack(tup): return concatenate(tup, axis=0) +def tile(x, reps): + """ + Construct an array by repeating A the number of times given by reps. + + If reps has length d, the result will have dimension of max(d, A.ndim). + + If A.ndim < d, A is promoted to be d-dimensional by prepending new axes. So a shape (3,) array is promoted to (1, 3) for 2-D replication, or shape (1, 1, 3) for 3-D replication. If this is not the desired behavior, promote A to d-dimensions manually before calling this function. + + If A.ndim > d, reps is promoted to A.ndim by pre-pending 1’s to it. Thus for an A of shape (2, 3, 4, 5), a reps of (2, 2) is treated as (1, 1, 2, 2). + """ + # input sanitation + # x is DNDarray + # x.dim >= 1 + + # calculate map of new gshape, lshape + + if len(reps) > x.ndim: + added_dims = len(reps) - x.ndim + new_shape = added_dims * (1,) + x.gshape + new_split = None if x.split is None else x.split + added_dims + x = x.reshape(new_shape, axis=new_split) + split = x.split + if split is None or reps[split] == 1: + # no repeats along the split axis: local operation + t_x = x._DNDarray__array + t_tiled = t_x.repeat(reps) + return factories.array(t_tiled, dtype=x.dtype, split=split, comm=x.comm) + else: + raise NotImplementedError("ht.tile() not implemented yet for repeats along the split axis") + # size = x.comm.Get_size() + # rank = x.comm.Get_rank() + # # repeats along the split axis: communication needed + # output_shape = tuple(s * r for s, r in zip(x.gshape, reps)) + # tiled = factories.empty(output_shape, dtype=x.dtype, split=split, comm=x.comm) + # current_offset, current_lshape, current_slice = x.comm.chunk(x.gshape, split) + # tiled_offset, tiled_lshape, tiled_slice = tiled.comm.chunk(tiled.gshape, split) + # t_current_map = x.create_lshape_map() + # t_tiled_map = tiled.create_lshape_map() + # # map offsets (torch tensor with shape (size, 2) ) + # t_offset_map = torch.stack( + # ( + # t_current_map[:, split].cumsum(0) - t_current_map[:, split], + # t_tiled_map[:, split].cumsum(0) - t_tiled_map[:, split], + # t_tiled_map[rank, split] - t_current_map[:, split] + 1, + # ), + # dim=1, + # ) + + # # col 0 = current offsets, col 1 = tiled offsets + # recv_rank = torch.where( + # 0 + # <= t_offset_map[:, 0] - t_offset_map[:, 1] + # <= t_tiled_map[:, split] - t_current_map[:, split] + 1 + # ) + + # # use distributed setitem! + # # then torch.repeat on non-distributed dimensions + # pass + + def topk(a, k, dim=None, largest=True, sorted=True, out=None): """ Returns the k highest entries in the array. 
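The distributed implementation in [PATCH 03/32] above (the same change reappears as [PATCH 06/32] below) decides which ranks must exchange data by intersecting half-open index intervals along the split axis: each repeated copy of a rank's local slab is compared with the output chunk owned by every rank, and a positive-length overlap means a send/receive pair. The following self-contained sketch illustrates just that interval-overlap test with plain torch and made-up offsets; there is no MPI here and the variable names only loosely follow the patch:

    import torch

    # start/end offsets of the input slabs along the split axis, one entry per rank
    x_starts = torch.tensor([0, 5, 10])
    x_ends = torch.tensor([5, 10, 15])
    # start/end offsets of the chunks of the tiled output along the split axis
    tiled_starts = torch.tensor([0, 10, 20])
    tiled_ends = torch.tensor([10, 20, 30])

    # pairwise intersection of input slab i with output chunk j via broadcasting
    max_starts = torch.max(x_starts.unsqueeze(1), tiled_starts.unsqueeze(0))
    min_ends = torch.min(x_ends.unsqueeze(1), tiled_ends.unsqueeze(0))
    overlaps = torch.where(min_ends - max_starts > 0)
    # rank overlaps[0][k] holds data destined for rank overlaps[1][k],
    # namely the global index range [max_starts, min_ends) of that pair

In the patch itself the input intervals are additionally shifted by multiples of the global extent along the split axis, once per repetition, before the same max/min/where test is applied.
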
From 0cda9e69bd4b2fb142562ff774eb19d119935bc4 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Tue, 8 Sep 2020 04:35:25 +0200 Subject: [PATCH 05/32] Fix output split in non_distributed tiling --- heat/core/manipulations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 9a834b50fc..f442c9751b 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2182,7 +2182,7 @@ def tile(x, reps): # no repeats along the split axis: local operation t_x = x._DNDarray__array t_tiled = t_x.repeat(reps) - return factories.array(t_tiled, dtype=x.dtype, split=split, comm=x.comm) + return factories.array(t_tiled, dtype=x.dtype, is_split=split, comm=x.comm) else: raise NotImplementedError("ht.tile() not implemented yet for repeats along the split axis") # size = x.comm.Get_size() From 852eb2a6b3d61d0f899b60484ad5fd04a7de3d1c Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 11 Sep 2020 11:39:54 +0200 Subject: [PATCH 06/32] Implement distributed tile() --- heat/core/manipulations.py | 129 ++++++++++++++++++++++++++++--------- 1 file changed, 98 insertions(+), 31 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index f442c9751b..17a70ede6f 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2178,42 +2178,109 @@ def tile(x, reps): new_split = None if x.split is None else x.split + added_dims x = x.reshape(new_shape, axis=new_split) split = x.split + t_x = x._DNDarray__array if split is None or reps[split] == 1: # no repeats along the split axis: local operation - t_x = x._DNDarray__array t_tiled = t_x.repeat(reps) return factories.array(t_tiled, dtype=x.dtype, is_split=split, comm=x.comm) else: - raise NotImplementedError("ht.tile() not implemented yet for repeats along the split axis") - # size = x.comm.Get_size() - # rank = x.comm.Get_rank() - # # repeats along the split axis: communication needed - # output_shape = tuple(s * r for s, r in zip(x.gshape, reps)) - # tiled = factories.empty(output_shape, dtype=x.dtype, split=split, comm=x.comm) - # current_offset, current_lshape, current_slice = x.comm.chunk(x.gshape, split) - # tiled_offset, tiled_lshape, tiled_slice = tiled.comm.chunk(tiled.gshape, split) - # t_current_map = x.create_lshape_map() - # t_tiled_map = tiled.create_lshape_map() - # # map offsets (torch tensor with shape (size, 2) ) - # t_offset_map = torch.stack( - # ( - # t_current_map[:, split].cumsum(0) - t_current_map[:, split], - # t_tiled_map[:, split].cumsum(0) - t_tiled_map[:, split], - # t_tiled_map[rank, split] - t_current_map[:, split] + 1, - # ), - # dim=1, - # ) - - # # col 0 = current offsets, col 1 = tiled offsets - # recv_rank = torch.where( - # 0 - # <= t_offset_map[:, 0] - t_offset_map[:, 1] - # <= t_tiled_map[:, split] - t_current_map[:, split] + 1 - # ) - - # # use distributed setitem! - # # then torch.repeat on non-distributed dimensions - # pass + # repeats along the split axis: communication needed + size = x.comm.Get_size() + rank = x.comm.Get_rank() + x_shape = x.gshape + # allocate tiled DNDarray + tiled_shape = tuple(s * r for s, r in zip(x_shape, reps)) + tiled = factories.empty(tiled_shape, dtype=x.dtype, split=split, comm=x.comm) + # collect slicing information from all processes. + # t_(...) 
indicates process-local torch tensors + lshape_maps = [] + slices_map = [] + t_0 = torch.tensor([0], dtype=torch.int32) + for array in [x, tiled]: + t_lshape_map = array.create_lshape_map() + lshape_maps.append(t_lshape_map) + t_slices = torch.cat((t_0, t_lshape_map[:, split])).cumsum(0) + t_slices_starts = t_slices[:size] + t_slices_ends = t_slices[1:] + slices_map.append([t_slices_starts, t_slices_ends]) + + t_slices_x, t_slices_tiled = slices_map + + # keep track of repetitions + # t_x_starts.shape, t_x_ends.shape changing from (size,) to (reps[split], size) + reps_indices = list(x_shape[split] * rep for rep in (range(reps[split]))) + t_reps_indices = torch.tensor(reps_indices, dtype=torch.int32).reshape(len(reps_indices), 1) + for i, t in enumerate(t_slices_x): + t = t.repeat((reps[split], 1)) + t += t_reps_indices + t_slices_x[i] = t + + # distribution logic on current rank: + distr_map = [] + slices_map = [] + for i in range(2): + if i == 0: + # send logic for x slices on rank + t_x_starts = t_slices_x[0][:, rank].reshape(reps[split], 1) + t_x_ends = t_slices_x[1][:, rank].reshape(reps[split], 1) + t_tiled_starts, t_tiled_ends = t_slices_tiled + else: + # recv logic for tiled slices on rank + t_x_starts, t_x_ends = t_slices_x + t_tiled_starts = t_slices_tiled[0][rank] + t_tiled_ends = t_slices_tiled[1][rank] + t_max_starts = torch.max(t_x_starts, t_tiled_starts) + t_min_ends = torch.min(t_x_ends, t_tiled_ends) + coords = torch.where(t_min_ends - t_max_starts > 0) + # remove repeat offset from slices if sending + if i == 0: + t_max_starts -= t_reps_indices + t_min_ends -= t_reps_indices + starts = t_max_starts[coords] + ends = t_min_ends[coords] + slices_map.append([starts, ends]) + distr_map.append(coords) + + send_map, recv_map = distr_map + send_slices, recv_slices = slices_map + send_to_ranks = send_map[1].tolist() + recv_from_ranks = recv_map[1].tolist() + + # allocate local buffer for incoming data + t_local_tile = torch.zeros( + tuple(lshape_maps[1][0].tolist()), dtype=x._DNDarray__array.dtype + ) + t_tiled = tiled._DNDarray__array + + send_slice = recv_slice = local_slice = tiled.ndim * (slice(None, None, None),) + offset_x, _, _ = x.comm.chunk(x.gshape, x.split) + offset_tiled, _, _ = tiled.comm.chunk(tiled.gshape, tiled.split) + + for i, r in enumerate(send_to_ranks): + start = send_slices[0][i] - offset_x + stop = send_slices[1][i] - offset_x + send_slice = send_slice[:split] + (slice(start, stop, None),) + send_slice[split + 1 :] + local_slice = ( + local_slice[:split] + (slice(0, stop - start, None),) + local_slice[split + 1 :] + ) + t_local_tile[local_slice] = t_x[send_slice] + x.comm.Send(t_local_tile, r) + for i, r in enumerate(recv_from_ranks): + start = recv_slices[0][i] - offset_tiled + stop = recv_slices[1][i] - offset_tiled + recv_slice = recv_slice[:split] + (slice(start, stop, None),) + recv_slice[split + 1 :] + local_slice = ( + local_slice[:split] + (slice(0, stop - start, None),) + local_slice[split + 1 :] + ) + x.comm.Recv(t_local_tile, r) + tiled._DNDarray__array[recv_slice] = t_local_tile[local_slice] + + # finally tile along non-split axes if needed. 
Needs change in slices definition above + # reps = list(reps) + # reps[split] = 1 + # tiled._DNDarray__array = tiled._DNDarray__array.repeat(reps) + + return tiled def topk(a, k, dim=None, largest=True, sorted=True, out=None): From 475b8c63d2252404fd071a29a4040a8050892aba Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 16 Sep 2020 11:17:20 +0200 Subject: [PATCH 07/32] Implement distributed tiling along non-split axis if tiling along split axis. --- heat/core/manipulations.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 17a70ede6f..799a2b8eac 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2171,7 +2171,7 @@ def tile(x, reps): # x.dim >= 1 # calculate map of new gshape, lshape - + reps = list(reps) if len(reps) > x.ndim: added_dims = len(reps) - x.ndim new_shape = added_dims * (1,) + x.gshape @@ -2184,13 +2184,14 @@ def tile(x, reps): t_tiled = t_x.repeat(reps) return factories.array(t_tiled, dtype=x.dtype, is_split=split, comm=x.comm) else: - # repeats along the split axis: communication needed + # repeats along the split axis: communication of the split-axis repeats only size = x.comm.Get_size() rank = x.comm.Get_rank() x_shape = x.gshape - # allocate tiled DNDarray - tiled_shape = tuple(s * r for s, r in zip(x_shape, reps)) - tiled = factories.empty(tiled_shape, dtype=x.dtype, split=split, comm=x.comm) + # allocate tiled DNDarray, at first along split axis only + split_reps = [rep if i == split else 1 for i, rep in enumerate(reps)] + split_tiled_shape = tuple(s * r for s, r in zip(x_shape, split_reps)) + tiled = factories.empty(split_tiled_shape, dtype=x.dtype, split=split, comm=x.comm) # collect slicing information from all processes. # t_(...) indicates process-local torch tensors lshape_maps = [] @@ -2206,7 +2207,7 @@ def tile(x, reps): t_slices_x, t_slices_tiled = slices_map - # keep track of repetitions + # keep track of repetitions: # t_x_starts.shape, t_x_ends.shape changing from (size,) to (reps[split], size) reps_indices = list(x_shape[split] * rep for rep in (range(reps[split]))) t_reps_indices = torch.tensor(reps_indices, dtype=torch.int32).reshape(len(reps_indices), 1) @@ -2246,15 +2247,16 @@ def tile(x, reps): send_to_ranks = send_map[1].tolist() recv_from_ranks = recv_map[1].tolist() - # allocate local buffer for incoming data + # allocate local buffers for incoming data t_local_tile = torch.zeros( tuple(lshape_maps[1][0].tolist()), dtype=x._DNDarray__array.dtype ) t_tiled = tiled._DNDarray__array - send_slice = recv_slice = local_slice = tiled.ndim * (slice(None, None, None),) - offset_x, _, _ = x.comm.chunk(x.gshape, x.split) - offset_tiled, _, _ = tiled.comm.chunk(tiled.gshape, tiled.split) + # send_slice = recv_slice = local_slice = tiled.ndim * (slice(None, None, None),) + offset_x, _, send_slice = x.comm.chunk(x.gshape, x.split) + offset_tiled, _, recv_slice = tiled.comm.chunk(tiled.gshape, tiled.split) + local_slice = send_slice for i, r in enumerate(send_to_ranks): start = send_slices[0][i] - offset_x @@ -2273,13 +2275,11 @@ def tile(x, reps): local_slice[:split] + (slice(0, stop - start, None),) + local_slice[split + 1 :] ) x.comm.Recv(t_local_tile, r) - tiled._DNDarray__array[recv_slice] = t_local_tile[local_slice] - - # finally tile along non-split axes if needed. 
Needs change in slices definition above - # reps = list(reps) - # reps[split] = 1 - # tiled._DNDarray__array = tiled._DNDarray__array.repeat(reps) + t_tiled[recv_slice] = t_local_tile[local_slice] + # finally tile along non-split axes if needed + reps[split] = 1 + tiled = factories.array(t_tiled.repeat(reps), dtype=x.dtype, is_split=0, comm=x.comm) return tiled From 722bea67a0815d0032c35ae92fa97f233f2038f0 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 16 Sep 2020 12:22:36 +0200 Subject: [PATCH 08/32] Introduce sanitation module. --- heat/core/manipulations.py | 17 +++++++----- heat/core/sanitation.py | 53 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 6 deletions(-) create mode 100644 heat/core/sanitation.py diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 799a2b8eac..d8a0c34f8f 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -12,6 +12,7 @@ from . import tiling from . import types from . import _operations +from . import sanitation __all__ = [ @@ -2166,12 +2167,16 @@ def tile(x, reps): If A.ndim > d, reps is promoted to A.ndim by pre-pending 1’s to it. Thus for an A of shape (2, 3, 4, 5), a reps of (2, 2) is treated as (1, 1, 2, 2). """ - # input sanitation - # x is DNDarray - # x.dim >= 1 - # calculate map of new gshape, lshape - reps = list(reps) + # check that input is DNDarray + sanitation.sanitize_input(x) + # check dimensions + if x.ndim == 0: + x = sanitation.to_1d(x) + # reps to list + reps = sanitation.sanitize_sequence(reps) + + # calculate new gshape, split if len(reps) > x.ndim: added_dims = len(reps) - x.ndim new_shape = added_dims * (1,) + x.gshape @@ -2193,7 +2198,7 @@ def tile(x, reps): split_tiled_shape = tuple(s * r for s, r in zip(x_shape, split_reps)) tiled = factories.empty(split_tiled_shape, dtype=x.dtype, split=split, comm=x.comm) # collect slicing information from all processes. - # t_(...) indicates process-local torch tensors + # "t_" indicates process-local torch tensors lshape_maps = [] slices_map = [] t_0 = torch.tensor([0], dtype=torch.int32) diff --git a/heat/core/sanitation.py b/heat/core/sanitation.py new file mode 100644 index 0000000000..5833340826 --- /dev/null +++ b/heat/core/sanitation.py @@ -0,0 +1,53 @@ +import numpy as np +import torch +import warnings + +from .communication import MPI + +from . import dndarray +from . import factories +from . import stride_tricks +from . import types + + +__all__ = ["sanitize_input", "sanitize_sequence", "to_1d"] + + +def sanitize_input(x): + """ + Raise TypeError if input is not DNDarray + """ + if not isinstance(x, dndarray.DNDarray): + raise TypeError("input must be a DNDarray, is {}".format(type(x))) + + +def sanitize_sequence(seq): + """ + if tuple, torch.tensor, dndarray --> return list + """ + if isinstance(seq, list): + return seq + elif isinstance(seq, tuple): + return list(seq) + elif isinstance(seq, dndarray.DNDarray): + if seq.split is None: + return seq._DNDarray__array.tolist() + else: + raise TypeError( + "seq is a distributed DNDarray, expected a list, a tuple, or a process-local array." + ) + elif isinstance(seq, torch.tensor): + return seq.tolist() + else: + raise TypeError( + "seq must be a list, a tuple, or a process-local array, got {}".format(type(seq)) + ) + + +def to_1d(x): + """ + Turn a scalar DNDarray into a 1-D DNDarray with 1 element. 
+ """ + return factories.array( + x._DNDarray__array.unsqueeze(0), dtype=x.dtype, split=x.split, comm=x.comm + ) From dc0083e249400fbde88202decf574c5852ef684b Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 17 Sep 2020 05:35:32 +0200 Subject: [PATCH 09/32] Implement tile case len(reps) < x.ndim --- heat/core/manipulations.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index b942b981c6..b840fcf13d 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2159,13 +2159,17 @@ def vstack(tup): def tile(x, reps): """ - Construct an array by repeating A the number of times given by reps. + Construct a new DNDarray by repeating 'x' the number of times given by 'reps'. - If reps has length d, the result will have dimension of max(d, A.ndim). + If 'reps' has length 'd', the result will have dimension of 'max(d, x.ndim)'. - If A.ndim < d, A is promoted to be d-dimensional by prepending new axes. So a shape (3,) array is promoted to (1, 3) for 2-D replication, or shape (1, 1, 3) for 3-D replication. If this is not the desired behavior, promote A to d-dimensions manually before calling this function. + If 'x.ndim < d', 'x' is promoted to be d-dimensional by prepending new axes. + So a shape (3,) array is promoted to (1, 3) for 2-D replication, or shape (1, 1, 3) + for 3-D replication. If this is not the desired behavior, promote 'x' to d-dimensions + manually before calling this function. - If A.ndim > d, reps is promoted to A.ndim by pre-pending 1’s to it. Thus for an A of shape (2, 3, 4, 5), a reps of (2, 2) is treated as (1, 1, 2, 2). + If 'x.ndim > d', 'reps' will replicate the last 'd' dimensions of 'x', i.e., if + 'x.shape' is (2, 3, 4, 5), a 'reps' of (2, 2) will be expanded to (1, 1, 2, 2). """ # check that input is DNDarray @@ -2177,12 +2181,16 @@ def tile(x, reps): reps = sanitation.sanitize_sequence(reps) # calculate new gshape, split - if len(reps) > x.ndim: - added_dims = len(reps) - x.ndim - new_shape = added_dims * (1,) + x.gshape - new_split = None if x.split is None else x.split + added_dims - x = x.reshape(new_shape, axis=new_split) + if len(reps) != x.ndim: + added_dims = abs(len(reps) - x.ndim) + if len(reps) > x.ndim: + new_shape = added_dims * (1,) + x.gshape + new_split = None if x.split is None else x.split + added_dims + x = x.reshape(new_shape, axis=new_split) + else: + reps = added_dims * [1] + reps split = x.split + # "t_" indicates process-local torch tensors t_x = x._DNDarray__array if split is None or reps[split] == 1: # no repeats along the split axis: local operation @@ -2193,12 +2201,11 @@ def tile(x, reps): size = x.comm.Get_size() rank = x.comm.Get_rank() x_shape = x.gshape - # allocate tiled DNDarray, at first along split axis only + # allocate tiled DNDarray, at first tiled along split axis only split_reps = [rep if i == split else 1 for i, rep in enumerate(reps)] split_tiled_shape = tuple(s * r for s, r in zip(x_shape, split_reps)) tiled = factories.empty(split_tiled_shape, dtype=x.dtype, split=split, comm=x.comm) # collect slicing information from all processes. 
- # "t_" indicates process-local torch tensors lshape_maps = [] slices_map = [] t_0 = torch.tensor([0], dtype=torch.int32) From a703bec7777318a2fdd71803ae93204bc35fedc2 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 17 Sep 2020 09:58:11 +0200 Subject: [PATCH 10/32] Implement test_tile() --- heat/core/manipulations.py | 2 +- heat/core/sanitation.py | 2 +- heat/core/tests/test_manipulations.py | 30 +++++++++++++++++++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index b840fcf13d..7d9f00889f 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2290,7 +2290,7 @@ def tile(x, reps): # finally tile along non-split axes if needed reps[split] = 1 - tiled = factories.array(t_tiled.repeat(reps), dtype=x.dtype, is_split=0, comm=x.comm) + tiled = factories.array(t_tiled.repeat(reps), dtype=x.dtype, is_split=split, comm=x.comm) return tiled diff --git a/heat/core/sanitation.py b/heat/core/sanitation.py index 5833340826..cc3a5cb384 100644 --- a/heat/core/sanitation.py +++ b/heat/core/sanitation.py @@ -36,7 +36,7 @@ def sanitize_sequence(seq): raise TypeError( "seq is a distributed DNDarray, expected a list, a tuple, or a process-local array." ) - elif isinstance(seq, torch.tensor): + elif isinstance(seq, torch.Tensor): return seq.tolist() else: raise TypeError( diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 5243a583be..3b6a3918da 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -1664,6 +1664,36 @@ def test_stack(self): ht.stack((ht_a_unbalanced, ht_b_split, ht_c_split)) # TODO test with DNDarrays on different devices + def test_tile(self): + # test local tile, tuple reps + x = ht.arange(12).reshape((4, 3)) + reps = (2, 1) + ht_tiled = ht.tile(x, reps) + np_tiled = np.tile(x.numpy(), reps) + self.assertTrue((np_tiled == ht_tiled.numpy()).all()) + self.assertTrue(ht_tiled.dtype is x.dtype) + + # test distributed tile along split axis + # reps is a DNDarray + # len(reps) > x.ndim + split = 0 + x = ht.random.randn(4, 3, split=split) + reps = ht.random.randint(2, 10, size=(4,)) + tiled_along_split = ht.tile(x, reps) + np_tiled_along_split = np.tile(x.numpy(), reps.numpy()) + self.assertTrue((tiled_along_split.numpy() == np_tiled_along_split).all()) + self.assertTrue(tiled_along_split.dtype is x.dtype) + + # test tile along non-split axis + # len(reps) < x.ndim + split = 1 + x = ht.random.randn(4, 5, 3, 10, dtype=ht.float64, split=split) + reps = torch.tensor((2, 2)) + tiled_along_split = ht.tile(x, reps) + np_tiled_along_split = np.tile(x.numpy(), reps.numpy()) + self.assertTrue((tiled_along_split.numpy() == np_tiled_along_split).all()) + self.assertTrue(tiled_along_split.dtype is x.dtype) + def test_topk(self): size = ht.MPI_WORLD.size if size == 1: From 0d1d6caf3a61fb0ef49d321927f5a38f7981a70a Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 17 Sep 2020 10:38:22 +0200 Subject: [PATCH 11/32] Check that reps is all integers, improve docs. --- heat/core/manipulations.py | 45 +++++++++++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 5 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 7d9f00889f..70ff962b03 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2161,15 +2161,48 @@ def tile(x, reps): """ Construct a new DNDarray by repeating 'x' the number of times given by 'reps'. 
- If 'reps' has length 'd', the result will have dimension of 'max(d, x.ndim)'. + If 'reps' has length 'd', the result will have 'max(d, x.ndim)' dimensions: - If 'x.ndim < d', 'x' is promoted to be d-dimensional by prepending new axes. + - if 'x.ndim < d', 'x' is promoted to be d-dimensional by prepending new axes. So a shape (3,) array is promoted to (1, 3) for 2-D replication, or shape (1, 1, 3) - for 3-D replication. If this is not the desired behavior, promote 'x' to d-dimensions - manually before calling this function. + for 3-D replication (if this is not the desired behavior, promote 'x' to d-dimensions + manually before calling this function); - If 'x.ndim > d', 'reps' will replicate the last 'd' dimensions of 'x', i.e., if + - if 'x.ndim > d', 'reps' will replicate the last 'd' dimensions of 'x', i.e., if 'x.shape' is (2, 3, 4, 5), a 'reps' of (2, 2) will be expanded to (1, 1, 2, 2). + + Parameters + ---------- + x : DNDarray + + reps : Sequence[ints,...] + + Returns + ------- + tiled : DNDarray + Split semantics: if `x` is distributed, the tiled data will be distributed along the + same dimension. Note that nominally `tiled.split != x.split` in the case where + `len(reps) > x.ndim`. See example below. + + Examples + -------- + >>> x = ht.arange(12).reshape((4,3)).resplit_(0) + >>> x + DNDarray([[ 0, 1, 2], + [ 3, 4, 5], + [ 6, 7, 8], + [ 9, 10, 11]], dtype=ht.int32, device=cpu:0, split=0) + >>> reps = (1, 2, 2) + >>> tiled = ht.tile(x, reps) + >>> tiled + DNDarray([[[ 0, 1, 2, 0, 1, 2], + [ 3, 4, 5, 3, 4, 5], + [ 6, 7, 8, 6, 7, 8], + [ 9, 10, 11, 9, 10, 11], + [ 0, 1, 2, 0, 1, 2], + [ 3, 4, 5, 3, 4, 5], + [ 6, 7, 8, 6, 7, 8], + [ 9, 10, 11, 9, 10, 11]]], dtype=ht.int32, device=cpu:0, split=1) """ # check that input is DNDarray @@ -2179,6 +2212,8 @@ def tile(x, reps): x = sanitation.to_1d(x) # reps to list reps = sanitation.sanitize_sequence(reps) + if not all(isinstance(rep, int) for rep in reps): + raise TypeError("reps must be a sequence of integers, got {}".format(reps)) # calculate new gshape, split if len(reps) != x.ndim: From 60cdd20d5cdca95cf40be332546512b9fe5df747 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 17 Sep 2020 12:58:33 +0200 Subject: [PATCH 12/32] Increase test coverage --- heat/core/manipulations.py | 1 - heat/core/tests/test_manipulations.py | 5 +++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 70ff962b03..f7c5574d7a 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2204,7 +2204,6 @@ def tile(x, reps): [ 6, 7, 8, 6, 7, 8], [ 9, 10, 11, 9, 10, 11]]], dtype=ht.int32, device=cpu:0, split=1) """ - # check that input is DNDarray sanitation.sanitize_input(x) # check dimensions diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 3b6a3918da..e6576b0155 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -1694,6 +1694,11 @@ def test_tile(self): self.assertTrue((tiled_along_split.numpy() == np_tiled_along_split).all()) self.assertTrue(tiled_along_split.dtype is x.dtype) + # test exceptions + float_reps = (1, 2, 2, 1.5) + with self.assertRaises(TypeError): + tiled_along_split = ht.tile(x, float_reps) + def test_topk(self): size = ht.MPI_WORLD.size if size == 1: From 5ec8b8027bc593f5c5b8826a0eefeac7356ce7cd Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 17 Sep 2020 13:33:35 +0200 Subject: [PATCH 13/32] Implement test_sanitation --- heat/core/__init__.py | 
1 + heat/core/manipulations.py | 2 +- heat/core/sanitation.py | 15 +++++++++------ 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/heat/core/__init__.py b/heat/core/__init__.py index 150ee331ad..1c1ec90a42 100644 --- a/heat/core/__init__.py +++ b/heat/core/__init__.py @@ -15,6 +15,7 @@ from . import random from .relational import * from .rounding import * +from .sanitation import * from .statistics import * from .dndarray import * from .tiling import * diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index f7c5574d7a..052a976ca6 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2208,7 +2208,7 @@ def tile(x, reps): sanitation.sanitize_input(x) # check dimensions if x.ndim == 0: - x = sanitation.to_1d(x) + x = sanitation.scalar_to_1d(x) # reps to list reps = sanitation.sanitize_sequence(reps) if not all(isinstance(rep, int) for rep in reps): diff --git a/heat/core/sanitation.py b/heat/core/sanitation.py index cc3a5cb384..8f6e3d26e9 100644 --- a/heat/core/sanitation.py +++ b/heat/core/sanitation.py @@ -10,7 +10,7 @@ from . import types -__all__ = ["sanitize_input", "sanitize_sequence", "to_1d"] +__all__ = ["sanitize_input", "sanitize_sequence", "scalar_to_1d"] def sanitize_input(x): @@ -33,7 +33,7 @@ def sanitize_sequence(seq): if seq.split is None: return seq._DNDarray__array.tolist() else: - raise TypeError( + raise ValueError( "seq is a distributed DNDarray, expected a list, a tuple, or a process-local array." ) elif isinstance(seq, torch.Tensor): @@ -44,10 +44,13 @@ def sanitize_sequence(seq): ) -def to_1d(x): +def scalar_to_1d(x): """ Turn a scalar DNDarray into a 1-D DNDarray with 1 element. """ - return factories.array( - x._DNDarray__array.unsqueeze(0), dtype=x.dtype, split=x.split, comm=x.comm - ) + if x.ndim == 0: + return factories.array( + x._DNDarray__array.unsqueeze(0), dtype=x.dtype, split=x.split, comm=x.comm + ) + else: + raise ValueError("expected a DNDarray scalar, got DNDarray with shape {}".format(x.shape)) From ba93fd285a923b6a0aa994cab52431fbdd344a84 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 17 Sep 2020 15:35:45 +0200 Subject: [PATCH 14/32] Expand sanitation docs --- heat/core/sanitation.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/heat/core/sanitation.py b/heat/core/sanitation.py index 8f6e3d26e9..34dbf58202 100644 --- a/heat/core/sanitation.py +++ b/heat/core/sanitation.py @@ -23,7 +23,15 @@ def sanitize_input(x): def sanitize_sequence(seq): """ - if tuple, torch.tensor, dndarray --> return list + Check if sequence is valid, return list. + + Parameters + ---------- + seq : Union[Sequence[ints, ...], Sequence[floats, ...], DNDarray, torch.tensor] + + Returns + ------- + seq : List """ if isinstance(seq, list): return seq @@ -47,10 +55,17 @@ def sanitize_sequence(seq): def scalar_to_1d(x): """ Turn a scalar DNDarray into a 1-D DNDarray with 1 element. 
+ + Parameters + ---------- + x : DNDarray + with `x.ndim = 0` + + Returns + ------- + x : DNDarray + where `x.ndim = 1` and `x.shape = (1,)` """ - if x.ndim == 0: - return factories.array( - x._DNDarray__array.unsqueeze(0), dtype=x.dtype, split=x.split, comm=x.comm - ) - else: - raise ValueError("expected a DNDarray scalar, got DNDarray with shape {}".format(x.shape)) + return factories.array( + x._DNDarray__array.unsqueeze(0), dtype=x.dtype, split=x.split, comm=x.comm + ) From 53320fd825eefdd805604423823dde17c633767a Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 17 Sep 2020 15:47:31 +0200 Subject: [PATCH 15/32] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 878e74a12f..95cf972af8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,6 +49,7 @@ - [#664](https://github.com/helmholtz-analytics/heat/pull/664) New feature / enhancement: `random.random_sample`, `random.random`, `random.sample`, `random.ranf`, `random.random_integer` - [#666](https://github.com/helmholtz-analytics/heat/pull/666) New feature: distributed prepend/append for diff(). - [#667](https://github.com/helmholtz-analytics/heat/pull/667) Enhancement `reshape`: rename axis parameter +- [#678](https://github.com/helmholtz-analytics/heat/pull/678) New feature: distributed `tile` # v0.4.0 From 4e4bbb42039da7b2521c186b266f0f3c5c867034 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 17 Sep 2020 16:05:16 +0200 Subject: [PATCH 16/32] Increase test coverage --- heat/core/tests/test_manipulations.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index e6576b0155..df4ddace9d 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -1673,6 +1673,14 @@ def test_tile(self): self.assertTrue((np_tiled == ht_tiled.numpy()).all()) self.assertTrue(ht_tiled.dtype is x.dtype) + # test scalar x + x = ht.array(9.0) + reps = (2, 1) + ht_tiled = ht.tile(x, reps) + np_tiled = np.tile(x.numpy(), reps) + self.assertTrue((np_tiled == ht_tiled.numpy()).all()) + self.assertTrue(ht_tiled.dtype is x.dtype) + # test distributed tile along split axis # reps is a DNDarray # len(reps) > x.ndim From 205bf76c946222cef2133c08022c926a6921fe82 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 17 Sep 2020 16:33:57 +0200 Subject: [PATCH 17/32] Set device for stand-alone torch tensor --- heat/core/manipulations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 052a976ca6..54ec229244 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2242,7 +2242,7 @@ def tile(x, reps): # collect slicing information from all processes. 
lshape_maps = [] slices_map = [] - t_0 = torch.tensor([0], dtype=torch.int32) + t_0 = torch.tensor([0], dtype=torch.int32, device=t_x.device) for array in [x, tiled]: t_lshape_map = array.create_lshape_map() lshape_maps.append(t_lshape_map) From 5059726c6ec4858c4ceb9630339faafbd1fb1b29 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Thu, 17 Sep 2020 16:39:16 +0200 Subject: [PATCH 18/32] Fix torch device --- heat/core/manipulations.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 54ec229244..85c1d91d3f 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -2256,7 +2256,9 @@ def tile(x, reps): # keep track of repetitions: # t_x_starts.shape, t_x_ends.shape changing from (size,) to (reps[split], size) reps_indices = list(x_shape[split] * rep for rep in (range(reps[split]))) - t_reps_indices = torch.tensor(reps_indices, dtype=torch.int32).reshape(len(reps_indices), 1) + t_reps_indices = torch.tensor(reps_indices, dtype=torch.int32, device=t_x.device).reshape( + len(reps_indices), 1 + ) for i, t in enumerate(t_slices_x): t = t.repeat((reps[split], 1)) t += t_reps_indices From 67802cdac5bf1adfa79718f76ae4e0780ea43c72 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Mon, 21 Sep 2020 17:06:05 +0200 Subject: [PATCH 19/32] black --- heat/core/sanitation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/heat/core/sanitation.py b/heat/core/sanitation.py index 78e00de15a..b839f5f552 100644 --- a/heat/core/sanitation.py +++ b/heat/core/sanitation.py @@ -10,7 +10,6 @@ from . import types - __all__ = ["sanitize_input", "sanitize_out", "sanitize_sequence", "scalar_to_1d"] From 39c8ec92c39298b0f66a8ecb801645706aef9542 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 9 Jul 2021 11:15:32 +0200 Subject: [PATCH 20/32] distributed ht.tile() overhaul via Alltoallv --- heat/core/manipulations.py | 93 +++++++++++++++++++++++--------------- 1 file changed, 56 insertions(+), 37 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 73d1688ff5..c24a7d3e3d 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -3305,7 +3305,7 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: [ 9, 10, 11, 9, 10, 11]]], dtype=ht.int32, device=cpu:0, split=1) """ # check that input is DNDarray - sanitation.sanitize_input(x) + sanitation.sanitize_in(x) # check dimensions if x.ndim == 0: x = sanitation.scalar_to_1d(x) @@ -3325,7 +3325,7 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: reps = added_dims * [1] + reps split = x.split # "t_" indicates process-local torch tensors - t_x = x._DNDarray__array + t_x = x.larray if split is None or reps[split] == 1: # no repeats along the split axis: local operation t_tiled = t_x.repeat(reps) @@ -3342,13 +3342,13 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: # collect slicing information from all processes. 
lshape_maps = [] slices_map = [] - t_0 = torch.tensor([0], dtype=torch.int32, device=t_x.device) for array in [x, tiled]: + counts, displs = array.counts_displs() + t_slices_starts = torch.tensor(displs, device=t_x.device) + t_slices_ends = t_slices_starts + torch.tensor(counts, device=t_x.device) + # TODO: replace following with lshape_map property when available t_lshape_map = array.create_lshape_map() lshape_maps.append(t_lshape_map) - t_slices = torch.cat((t_0, t_lshape_map[:, split])).cumsum(0) - t_slices_starts = t_slices[:size] - t_slices_ends = t_slices[1:] slices_map.append([t_slices_starts, t_slices_ends]) t_slices_x, t_slices_tiled = slices_map @@ -3385,44 +3385,63 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: if i == 0: t_max_starts -= t_reps_indices t_min_ends -= t_reps_indices - starts = t_max_starts[coords] - ends = t_min_ends[coords] - slices_map.append([starts, ends]) + starts = t_max_starts[coords].unsqueeze_(0) + ends = t_min_ends[coords].unsqueeze_(0) + slices_map.append(torch.cat((starts, ends), dim=0)) distr_map.append(coords) + # bookkeeping in preparation for Alltoallv send_map, recv_map = distr_map + send_rep, send_to_ranks = send_map + recv_rep, recv_from_ranks = recv_map send_slices, recv_slices = slices_map - send_to_ranks = send_map[1].tolist() - recv_from_ranks = recv_map[1].tolist() - # allocate local buffers for incoming data - t_local_tile = torch.zeros( - tuple(lshape_maps[1][0].tolist()), dtype=x._DNDarray__array.dtype - ) - t_tiled = tiled._DNDarray__array - - offset_x, _, send_slice = x.comm.chunk(x.gshape, x.split) - offset_tiled, _, recv_slice = tiled.comm.chunk(tiled.gshape, tiled.split) - local_slice = send_slice - - for i, r in enumerate(send_to_ranks): - start = send_slices[0][i] - offset_x - stop = send_slices[1][i] - offset_x - send_slice = send_slice[:split] + (slice(start, stop, None),) + send_slice[split + 1 :] - local_slice = ( - local_slice[:split] + (slice(0, stop - start, None),) + local_slice[split + 1 :] + offset_x, _, _ = x.comm.chunk(x.gshape, x.split) + offset_tiled, _, _ = tiled.comm.chunk(tiled.gshape, tiled.split) + t_tiled = tiled.larray + + active_send_counts = send_slices.clone() + active_send_counts[0] *= -1 + active_send_counts = active_send_counts.sum(0) + active_recv_counts = recv_slices.clone() + active_recv_counts[0] *= -1 + active_recv_counts = active_recv_counts.sum(0) + send_slices -= offset_x + recv_slices -= offset_tiled + recv_data = t_tiled.clone() + # we need as many Alltoallv calls as repeats along the split axis + for rep in range(reps[split]): + # send_data, send_counts, send_displs on rank + all_send_counts = [0] * size + all_send_displs = [0] * size + send_this_rep = torch.where(send_rep == rep)[0].tolist() + dest_this_rep = send_to_ranks[send_this_rep].tolist() + for i, j in zip(send_this_rep, dest_this_rep): + all_send_counts[j] = active_send_counts[i].item() + all_send_displs[j] = send_slices[0][i].item() + local_send_slice = [slice(None)] * x.ndim + local_send_slice[split] = slice( + all_send_displs[0], all_send_displs[0] + sum(all_send_counts) + ) + send_data = t_x[local_send_slice].clone() + + # recv_data, recv_counts, recv_displs on rank + all_recv_counts = [0] * size + all_recv_displs = [0] * size + recv_this_rep = torch.where(recv_rep == rep)[0].tolist() + orig_this_rep = recv_from_ranks[recv_this_rep].tolist() + for i, j in zip(recv_this_rep, orig_this_rep): + all_recv_counts[j] = active_recv_counts[i].item() + all_recv_displs[j] = recv_slices[0][i].item() + local_recv_slice = 
[slice(None)] * x.ndim + local_recv_slice[split] = slice( + all_recv_displs[0], all_recv_displs[0] + sum(all_recv_counts) ) - t_local_tile[local_slice] = t_x[send_slice] - x.comm.Send(t_local_tile, r) - for i, r in enumerate(recv_from_ranks): - start = recv_slices[0][i] - offset_tiled - stop = recv_slices[1][i] - offset_tiled - recv_slice = recv_slice[:split] + (slice(start, stop, None),) + recv_slice[split + 1 :] - local_slice = ( - local_slice[:split] + (slice(0, stop - start, None),) + local_slice[split + 1 :] + x.comm.Alltoallv( + (send_data, all_send_counts, all_send_displs), + (recv_data, all_recv_counts, all_recv_displs), ) - x.comm.Recv(t_local_tile, r) - t_tiled[recv_slice] = t_local_tile[local_slice] + t_tiled[local_recv_slice] = recv_data[local_recv_slice] # finally tile along non-split axes if needed reps[split] = 1 From dfdf9b3be1c1d813456dafbd13f1691f56b460d8 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 9 Jul 2021 11:22:08 +0200 Subject: [PATCH 21/32] Remove dead code --- heat/core/manipulations.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index c24a7d3e3d..49b315dd66 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -3340,15 +3340,11 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: split_tiled_shape = tuple(s * r for s, r in zip(x_shape, split_reps)) tiled = factories.empty(split_tiled_shape, dtype=x.dtype, split=split, comm=x.comm) # collect slicing information from all processes. - lshape_maps = [] slices_map = [] for array in [x, tiled]: counts, displs = array.counts_displs() t_slices_starts = torch.tensor(displs, device=t_x.device) t_slices_ends = t_slices_starts + torch.tensor(counts, device=t_x.device) - # TODO: replace following with lshape_map property when available - t_lshape_map = array.create_lshape_map() - lshape_maps.append(t_lshape_map) slices_map.append([t_slices_starts, t_slices_ends]) t_slices_x, t_slices_tiled = slices_map From 117db7c0a57746eadee9ad0f95deb211eb8b63e8 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 9 Jul 2021 11:29:25 +0200 Subject: [PATCH 22/32] Fix tests --- heat/core/tests/test_manipulations.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index dd3f76288f..433c059ac8 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -3068,9 +3068,9 @@ def test_tile(self): # len(reps) > x.ndim split = 0 x = ht.random.randn(4, 3, split=split) - reps = ht.random.randint(2, 10, size=(4,)) + reps = ht.random.randint(2, 10, size=(4,)).tolist() tiled_along_split = ht.tile(x, reps) - np_tiled_along_split = np.tile(x.numpy(), reps.numpy()) + np_tiled_along_split = np.tile(x.numpy(), reps) self.assertTrue((tiled_along_split.numpy() == np_tiled_along_split).all()) self.assertTrue(tiled_along_split.dtype is x.dtype) @@ -3078,9 +3078,9 @@ def test_tile(self): # len(reps) < x.ndim split = 1 x = ht.random.randn(4, 5, 3, 10, dtype=ht.float64, split=split) - reps = torch.tensor((2, 2)) + reps = (2, 2) tiled_along_split = ht.tile(x, reps) - np_tiled_along_split = np.tile(x.numpy(), reps.numpy()) + np_tiled_along_split = np.tile(x.numpy(), reps) self.assertTrue((tiled_along_split.numpy() == np_tiled_along_split).all()) self.assertTrue(tiled_along_split.dtype is x.dtype) From 6901f108ba511a5880145344b44bb48d5943c1d4 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 9 Jul 2021 11:39:04 
+0200 Subject: [PATCH 23/32] clean up --- heat/core/manipulations.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 49b315dd66..8fcb0ffaa5 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -3404,7 +3404,7 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: active_recv_counts = active_recv_counts.sum(0) send_slices -= offset_x recv_slices -= offset_tiled - recv_data = t_tiled.clone() + recv_buf = t_tiled.clone() # we need as many Alltoallv calls as repeats along the split axis for rep in range(reps[split]): # send_data, send_counts, send_displs on rank @@ -3419,7 +3419,7 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: local_send_slice[split] = slice( all_send_displs[0], all_send_displs[0] + sum(all_send_counts) ) - send_data = t_x[local_send_slice].clone() + send_buf = t_x[local_send_slice].clone() # recv_data, recv_counts, recv_displs on rank all_recv_counts = [0] * size @@ -3434,10 +3434,10 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: all_recv_displs[0], all_recv_displs[0] + sum(all_recv_counts) ) x.comm.Alltoallv( - (send_data, all_send_counts, all_send_displs), - (recv_data, all_recv_counts, all_recv_displs), + (send_buf, all_send_counts, all_send_displs), + (recv_buf, all_recv_counts, all_recv_displs), ) - t_tiled[local_recv_slice] = recv_data[local_recv_slice] + t_tiled[local_recv_slice] = recv_buf[local_recv_slice] # finally tile along non-split axes if needed reps[split] = 1 From 808b32023fc10b88d2064cc780ef76940317e335 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 9 Jul 2021 12:31:55 +0200 Subject: [PATCH 24/32] Cast out_gshape to tuple after torch multiplication --- heat/core/manipulations.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 8fcb0ffaa5..fbc272a2ea 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -3326,10 +3326,22 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: split = x.split # "t_" indicates process-local torch tensors t_x = x.larray + out_gshape = tuple( + torch.tensor(x.shape, device=t_x.device) * torch.tensor(reps, device=t_x.device) + ) if split is None or reps[split] == 1: # no repeats along the split axis: local operation t_tiled = t_x.repeat(reps) - return factories.array(t_tiled, dtype=x.dtype, is_split=split, comm=x.comm) + # return factories.array(t_tiled, dtype=x.dtype, is_split=split, comm=x.comm) + return DNDarray( + t_tiled, + out_gshape, + dtype=x.dtype, + split=split, + device=x.device, + comm=x.comm, + balanced=x.balanced, + ) else: # repeats along the split axis: communication of the split-axis repeats only size = x.comm.Get_size() @@ -3441,8 +3453,15 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: # finally tile along non-split axes if needed reps[split] = 1 - tiled = factories.array(t_tiled.repeat(reps), dtype=x.dtype, is_split=split, comm=x.comm) - return tiled + return DNDarray( + t_tiled.repeat(reps), + out_gshape, + dtype=x.dtype, + split=split, + device=x.device, + comm=x.comm, + balanced=True, + ) def topk( From ef841997b5a9442ea30a09de0f93e93fa2235865 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 9 Jul 2021 12:50:18 +0200 Subject: [PATCH 25/32] Do not assume input is balanced --- heat/core/manipulations.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git 
a/heat/core/manipulations.py b/heat/core/manipulations.py index fbc272a2ea..bb47d6d4b0 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -3332,7 +3332,6 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: if split is None or reps[split] == 1: # no repeats along the split axis: local operation t_tiled = t_x.repeat(reps) - # return factories.array(t_tiled, dtype=x.dtype, is_split=split, comm=x.comm) return DNDarray( t_tiled, out_gshape, @@ -3404,7 +3403,10 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: recv_rep, recv_from_ranks = recv_map send_slices, recv_slices = slices_map - offset_x, _, _ = x.comm.chunk(x.gshape, x.split) + # do not assume that `x` is balanced + _, displs = x.counts_displs() + offset_x = displs[rank] + # impose load-balance on output offset_tiled, _, _ = tiled.comm.chunk(tiled.gshape, tiled.split) t_tiled = tiled.larray From 900c8cc2d405e083589980d1ccca4fd6a77025a5 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 9 Jul 2021 12:51:02 +0200 Subject: [PATCH 26/32] Test imbalanced input case --- heat/core/tests/test_manipulations.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 433c059ac8..011866e736 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -3064,7 +3064,6 @@ def test_tile(self): self.assertTrue(ht_tiled.dtype is x.dtype) # test distributed tile along split axis - # reps is a DNDarray # len(reps) > x.ndim split = 0 x = ht.random.randn(4, 3, split=split) @@ -3074,6 +3073,18 @@ def test_tile(self): self.assertTrue((tiled_along_split.numpy() == np_tiled_along_split).all()) self.assertTrue(tiled_along_split.dtype is x.dtype) + # test distributed tile() on imbalanced DNDarray + x = ht.random.randn(100, split=split) + x = x[ht.where(x > 0)] + reps = (5,) + imbalanced_tiled_along_split = ht.tile(x, reps) + np_imbalanced_tiled_along_split = np.tile(x.numpy(), reps) + self.assertTrue( + (imbalanced_tiled_along_split.numpy() == np_imbalanced_tiled_along_split).all() + ) + self.assertTrue(imbalanced_tiled_along_split.dtype is x.dtype) + self.assertTrue(imbalanced_tiled_along_split.is_balanced(force_check=True)) + # test tile along non-split axis # len(reps) < x.ndim split = 1 From 0c20652553c0f66ff1f763e10328e515db7ba9c4 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 9 Jul 2021 13:50:29 +0200 Subject: [PATCH 27/32] operate on transposed data if original split is not 0 --- heat/core/manipulations.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index bb47d6d4b0..1b1424aed7 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -3342,9 +3342,19 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: balanced=x.balanced, ) else: - # repeats along the split axis: communication of the split-axis repeats only + # repeats along the split axis size = x.comm.Get_size() rank = x.comm.Get_rank() + # make sure we work along dim 0 + trans_axes = list(range(x.ndim)) + if split != 0: + trans_axes[0], trans_axes[split] = split, 0 + reps[0], reps[split] = reps[split], reps[0] + x = linalg.transpose(x, trans_axes) + split = 0 + out_gshape = tuple( + torch.tensor(x.shape, device=t_x.device) * torch.tensor(reps, device=t_x.device) + ) x_shape = x.gshape # allocate tiled DNDarray, at first tiled along split axis only split_reps = [rep if 
i == split else 1 for i, rep in enumerate(reps)] @@ -3455,7 +3465,7 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: # finally tile along non-split axes if needed reps[split] = 1 - return DNDarray( + tiled = DNDarray( t_tiled.repeat(reps), out_gshape, dtype=x.dtype, @@ -3464,6 +3474,12 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: comm=x.comm, balanced=True, ) + if trans_axes != list(range(x.ndim)): + # transpose back to original shape + x = linalg.transpose(x, trans_axes) + tiled = linalg.transpose(tiled, trans_axes) + + return tiled def topk( From 2531a7cb29be2b77273510c94f3b0a1a6fc819d5 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 9 Jul 2021 13:51:38 +0200 Subject: [PATCH 28/32] Test for non-zero split as well --- heat/core/tests/test_manipulations.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index 011866e736..6405c4f6aa 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -3065,6 +3065,16 @@ def test_tile(self): # test distributed tile along split axis # len(reps) > x.ndim + split = 1 + x = ht.random.randn(4, 3, split=split) + reps = ht.random.randint(2, 10, size=(4,)).tolist() + tiled_along_split = ht.tile(x, reps) + np_tiled_along_split = np.tile(x.numpy(), reps) + self.assertTrue((tiled_along_split.numpy() == np_tiled_along_split).all()) + self.assertTrue(tiled_along_split.dtype is x.dtype) + + # test distributed tile along non-zero split axis + # len(reps) > x.ndim split = 0 x = ht.random.randn(4, 3, split=split) reps = ht.random.randint(2, 10, size=(4,)).tolist() @@ -3074,7 +3084,7 @@ def test_tile(self): self.assertTrue(tiled_along_split.dtype is x.dtype) # test distributed tile() on imbalanced DNDarray - x = ht.random.randn(100, split=split) + x = ht.random.randn(100, split=0) x = x[ht.where(x > 0)] reps = (5,) imbalanced_tiled_along_split = ht.tile(x, reps) From 9b8dfaf515b8618a92b6525dd4d2397acf674768 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Fri, 9 Jul 2021 15:16:27 +0200 Subject: [PATCH 29/32] Fix out_gshape (tuple of ints, not tuple of torch tensors) --- heat/core/manipulations.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 1b1424aed7..1b39fcbd3e 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -3327,7 +3327,7 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: # "t_" indicates process-local torch tensors t_x = x.larray out_gshape = tuple( - torch.tensor(x.shape, device=t_x.device) * torch.tensor(reps, device=t_x.device) + (torch.tensor(x.shape, device=t_x.device) * torch.tensor(reps, device=t_x.device)).tolist() ) if split is None or reps[split] == 1: # no repeats along the split axis: local operation @@ -3353,7 +3353,9 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: x = linalg.transpose(x, trans_axes) split = 0 out_gshape = tuple( - torch.tensor(x.shape, device=t_x.device) * torch.tensor(reps, device=t_x.device) + ( + torch.tensor(x.shape, device=t_x.device) * torch.tensor(reps, device=t_x.device) + ).tolist() ) x_shape = x.gshape # allocate tiled DNDarray, at first tiled along split axis only From 5a4f3f70b6300897a7692021aa3cd25b742ed01b Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 11 Aug 2021 11:33:21 +0200 Subject: [PATCH 30/32] Improve sanitation logic, fix bug occurring when `x.split != 0` --- 
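Note on the `x_proxy` construct used below: a one-element torch tensor viewed with zero strides stands in for the global array, so `Tensor.repeat` can both validate `reps` (torch raises TypeError/RuntimeError for malformed repeat arguments, which the sanitation logic catches) and yield the tiled global shape without touching the distributed data. A minimal standalone sketch, with made-up shape and reps values for illustration only:

import torch

# stand-in values for illustration; in the patch they come from the DNDarray
gshape = (4, 3)   # global shape of the distributed array
reps = (2, 1, 2)  # requested repetitions; len(reps) > ndim is allowed

# one stored element, viewed with zero strides, mimics the global shape
x_proxy = torch.ones((1,)).as_strided(gshape, [0] * len(gshape))

# repeat() prepends dimensions the same way np.tile does
out_gshape = tuple(x_proxy.repeat(reps).shape)  # -> (2, 4, 6)
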
heat/core/manipulations.py | 138 ++++++++++++++++++++++--------------- 1 file changed, 83 insertions(+), 55 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index eb878d2c7f..5858330e49 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -3530,82 +3530,110 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: [ 6, 7, 8, 6, 7, 8], [ 9, 10, 11, 9, 10, 11]]], dtype=ht.int32, device=cpu:0, split=1) """ - # check that input is DNDarray - sanitation.sanitize_in(x) - # check dimensions - if x.ndim == 0: - x = sanitation.scalar_to_1d(x) - # reps to list - reps = sanitation.sanitize_sequence(reps) - if not all(isinstance(rep, int) for rep in reps): - raise TypeError("reps must be a sequence of integers, got {}".format(reps)) - - # calculate new gshape, split + # x can be DNDarray or scalar + try: + _ = x.larray + except AttributeError: + try: + _ = x.shape + raise TypeError("Input can be a DNDarray or a scalar, is {}".format(type(x))) + except AttributeError: + x = factories.array(x).reshape(1) + + x_proxy = torch.ones((1,)).as_strided(x.gshape, [0] * x.ndim) + + # torch-proof args/kwargs: + # torch `reps`: int or sequence of ints; numpy `reps`: can be array-like + try: + _ = x_proxy.repeat(reps) + except TypeError: + # `reps` is array-like or contains non-int elements + try: + reps = resplit(reps, None).tolist() + except AttributeError: + try: + reps = reps.tolist() + except AttributeError: + try: + _ = x_proxy.repeat(reps) + except TypeError: + raise TypeError( + "reps must be a sequence of ints, got {}".format( + list(type(i) for i in reps) + ) + ) + except RuntimeError: + pass + except RuntimeError: + pass + + try: + reps = list(reps) + except TypeError: + # scalar to list + reps = [reps] + + # torch reps vs. 
numpy reps: dimensions if len(reps) != x.ndim: added_dims = abs(len(reps) - x.ndim) if len(reps) > x.ndim: new_shape = added_dims * (1,) + x.gshape new_split = None if x.split is None else x.split + added_dims - x = x.reshape(new_shape, axis=new_split) + x = x.reshape(new_shape, new_split=new_split) else: reps = added_dims * [1] + reps - split = x.split - # "t_" indicates process-local torch tensors - t_x = x.larray - out_gshape = tuple( - (torch.tensor(x.shape, device=t_x.device) * torch.tensor(reps, device=t_x.device)).tolist() - ) - if split is None or reps[split] == 1: + + out_gshape = tuple(x_proxy.repeat(reps).shape) + + if not x.is_distributed() or reps[x.split] == 1: # no repeats along the split axis: local operation - t_tiled = t_x.repeat(reps) + t_tiled = x.larray.repeat(reps) + out_gshape = tuple(x_proxy.repeat(reps).shape) return DNDarray( t_tiled, out_gshape, dtype=x.dtype, - split=split, + split=x.split, device=x.device, comm=x.comm, balanced=x.balanced, ) else: - # repeats along the split axis + # repeats along the split axis, work along dim 0 size = x.comm.Get_size() rank = x.comm.Get_rank() - # make sure we work along dim 0 trans_axes = list(range(x.ndim)) - if split != 0: - trans_axes[0], trans_axes[split] = split, 0 - reps[0], reps[split] = reps[split], reps[0] + if x.split != 0: + trans_axes[0], trans_axes[x.split] = x.split, 0 + reps[0], reps[x.split] = reps[x.split], reps[0] x = linalg.transpose(x, trans_axes) - split = 0 - out_gshape = tuple( - ( - torch.tensor(x.shape, device=t_x.device) * torch.tensor(reps, device=t_x.device) - ).tolist() - ) - x_shape = x.gshape + x_proxy = torch.ones((1,)).as_strided(x.gshape, [0] * x.ndim) + out_gshape = tuple(x_proxy.repeat(reps).shape) + + local_x = x.larray + # allocate tiled DNDarray, at first tiled along split axis only - split_reps = [rep if i == split else 1 for i, rep in enumerate(reps)] - split_tiled_shape = tuple(s * r for s, r in zip(x_shape, split_reps)) - tiled = factories.empty(split_tiled_shape, dtype=x.dtype, split=split, comm=x.comm) + split_reps = [rep if i == x.split else 1 for i, rep in enumerate(reps)] + split_tiled_shape = tuple(x_proxy.repeat(split_reps).shape) + tiled = factories.empty(split_tiled_shape, dtype=x.dtype, split=x.split, comm=x.comm) # collect slicing information from all processes. 
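        # DNDarray.counts_displs() returns, per rank, the number of elements held
        # along the split axis and their global offsets; the start/end tensors built
        # here are therefore the global index ranges owned by each rank, once for the
        # input `x` and once for the output tiled along the split axis.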
slices_map = [] for array in [x, tiled]: counts, displs = array.counts_displs() - t_slices_starts = torch.tensor(displs, device=t_x.device) - t_slices_ends = t_slices_starts + torch.tensor(counts, device=t_x.device) + t_slices_starts = torch.tensor(displs, device=local_x.device) + t_slices_ends = t_slices_starts + torch.tensor(counts, device=local_x.device) slices_map.append([t_slices_starts, t_slices_ends]) t_slices_x, t_slices_tiled = slices_map # keep track of repetitions: - # t_x_starts.shape, t_x_ends.shape changing from (size,) to (reps[split], size) - reps_indices = list(x_shape[split] * rep for rep in (range(reps[split]))) - t_reps_indices = torch.tensor(reps_indices, dtype=torch.int32, device=t_x.device).reshape( - len(reps_indices), 1 - ) + # local_x_starts.shape, local_x_ends.shape changing from (size,) to (reps[split], size) + reps_indices = list(x.gshape[x.split] * rep for rep in (range(reps[x.split]))) + t_reps_indices = torch.tensor( + reps_indices, dtype=torch.int32, device=local_x.device + ).reshape(len(reps_indices), 1) for i, t in enumerate(t_slices_x): - t = t.repeat((reps[split], 1)) + t = t.repeat((reps[x.split], 1)) t += t_reps_indices t_slices_x[i] = t @@ -3615,16 +3643,16 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: for i in range(2): if i == 0: # send logic for x slices on rank - t_x_starts = t_slices_x[0][:, rank].reshape(reps[split], 1) - t_x_ends = t_slices_x[1][:, rank].reshape(reps[split], 1) + local_x_starts = t_slices_x[0][:, rank].reshape(reps[x.split], 1) + local_x_ends = t_slices_x[1][:, rank].reshape(reps[x.split], 1) t_tiled_starts, t_tiled_ends = t_slices_tiled else: # recv logic for tiled slices on rank - t_x_starts, t_x_ends = t_slices_x + local_x_starts, local_x_ends = t_slices_x t_tiled_starts = t_slices_tiled[0][rank] t_tiled_ends = t_slices_tiled[1][rank] - t_max_starts = torch.max(t_x_starts, t_tiled_starts) - t_min_ends = torch.min(t_x_ends, t_tiled_ends) + t_max_starts = torch.max(local_x_starts, t_tiled_starts) + t_min_ends = torch.min(local_x_ends, t_tiled_ends) coords = torch.where(t_min_ends - t_max_starts > 0) # remove repeat offset from slices if sending if i == 0: @@ -3658,7 +3686,7 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: recv_slices -= offset_tiled recv_buf = t_tiled.clone() # we need as many Alltoallv calls as repeats along the split axis - for rep in range(reps[split]): + for rep in range(reps[x.split]): # send_data, send_counts, send_displs on rank all_send_counts = [0] * size all_send_displs = [0] * size @@ -3668,10 +3696,10 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: all_send_counts[j] = active_send_counts[i].item() all_send_displs[j] = send_slices[0][i].item() local_send_slice = [slice(None)] * x.ndim - local_send_slice[split] = slice( + local_send_slice[x.split] = slice( all_send_displs[0], all_send_displs[0] + sum(all_send_counts) ) - send_buf = t_x[local_send_slice].clone() + send_buf = local_x[local_send_slice].clone() # recv_data, recv_counts, recv_displs on rank all_recv_counts = [0] * size @@ -3682,7 +3710,7 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: all_recv_counts[j] = active_recv_counts[i].item() all_recv_displs[j] = recv_slices[0][i].item() local_recv_slice = [slice(None)] * x.ndim - local_recv_slice[split] = slice( + local_recv_slice[x.split] = slice( all_recv_displs[0], all_recv_displs[0] + sum(all_recv_counts) ) x.comm.Alltoallv( @@ -3692,12 +3720,12 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: 
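            # recv_buf now holds the slabs received in this repetition at their target
            # displacements; copy only the slice received in this round into the local
            # tiled tensor, since the rest of recv_buf may still contain stale values.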
t_tiled[local_recv_slice] = recv_buf[local_recv_slice] # finally tile along non-split axes if needed - reps[split] = 1 + reps[x.split] = 1 tiled = DNDarray( t_tiled.repeat(reps), out_gshape, dtype=x.dtype, - split=split, + split=x.split, device=x.device, comm=x.comm, balanced=True, From a6921336390d39f087d8b660566325cf269e0956 Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 11 Aug 2021 11:35:19 +0200 Subject: [PATCH 31/32] Expand tests for array-like reps --- heat/core/tests/test_manipulations.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/heat/core/tests/test_manipulations.py b/heat/core/tests/test_manipulations.py index c5e80b63e1..91f9453f15 100644 --- a/heat/core/tests/test_manipulations.py +++ b/heat/core/tests/test_manipulations.py @@ -3261,9 +3261,9 @@ def test_tile(self): # len(reps) > x.ndim split = 1 x = ht.random.randn(4, 3, split=split) - reps = ht.random.randint(2, 10, size=(4,)).tolist() + reps = ht.random.randint(2, 10, size=(4,)) tiled_along_split = ht.tile(x, reps) - np_tiled_along_split = np.tile(x.numpy(), reps) + np_tiled_along_split = np.tile(x.numpy(), reps.tolist()) self.assertTrue((tiled_along_split.numpy() == np_tiled_along_split).all()) self.assertTrue(tiled_along_split.dtype is x.dtype) @@ -3271,7 +3271,7 @@ def test_tile(self): # len(reps) > x.ndim split = 0 x = ht.random.randn(4, 3, split=split) - reps = ht.random.randint(2, 10, size=(4,)).tolist() + reps = np.random.randint(2, 10, size=(4,)) tiled_along_split = ht.tile(x, reps) np_tiled_along_split = np.tile(x.numpy(), reps) self.assertTrue((tiled_along_split.numpy() == np_tiled_along_split).all()) @@ -3280,7 +3280,7 @@ def test_tile(self): # test distributed tile() on imbalanced DNDarray x = ht.random.randn(100, split=0) x = x[ht.where(x > 0)] - reps = (5,) + reps = 5 imbalanced_tiled_along_split = ht.tile(x, reps) np_imbalanced_tiled_along_split = np.tile(x.numpy(), reps) self.assertTrue( @@ -3294,6 +3294,16 @@ def test_tile(self): split = 1 x = ht.random.randn(4, 5, 3, 10, dtype=ht.float64, split=split) reps = (2, 2) + tiled_along_non_split = ht.tile(x, reps) + np_tiled_along_non_split = np.tile(x.numpy(), reps) + self.assertTrue((tiled_along_non_split.numpy() == np_tiled_along_non_split).all()) + self.assertTrue(tiled_along_non_split.dtype is x.dtype) + + # test tile along split axis + # len(reps) = x.ndim + split = 1 + x = ht.random.randn(3, 3, dtype=ht.float64, split=split) + reps = (2, 3) tiled_along_split = ht.tile(x, reps) np_tiled_along_split = np.tile(x.numpy(), reps) self.assertTrue((tiled_along_split.numpy() == np_tiled_along_split).all()) @@ -3303,6 +3313,9 @@ def test_tile(self): float_reps = (1, 2, 2, 1.5) with self.assertRaises(TypeError): tiled_along_split = ht.tile(x, float_reps) + arraylike_float_reps = torch.tensor(float_reps) + with self.assertRaises(TypeError): + tiled_along_split = ht.tile(x, arraylike_float_reps) def test_topk(self): size = ht.MPI_WORLD.size From f412b0507815f07193be4e2bd03ac9a398f58f6e Mon Sep 17 00:00:00 2001 From: Claudia Comito Date: Wed, 11 Aug 2021 11:43:51 +0200 Subject: [PATCH 32/32] Remove unnecessary `else:` block --- heat/core/manipulations.py | 269 ++++++++++++++++++------------------- 1 file changed, 134 insertions(+), 135 deletions(-) diff --git a/heat/core/manipulations.py b/heat/core/manipulations.py index 5858330e49..7f54910918 100644 --- a/heat/core/manipulations.py +++ b/heat/core/manipulations.py @@ -3598,144 +3598,143 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray: comm=x.comm, 
balanced=x.balanced, ) - else: - # repeats along the split axis, work along dim 0 - size = x.comm.Get_size() - rank = x.comm.Get_rank() - trans_axes = list(range(x.ndim)) - if x.split != 0: - trans_axes[0], trans_axes[x.split] = x.split, 0 - reps[0], reps[x.split] = reps[x.split], reps[0] - x = linalg.transpose(x, trans_axes) - x_proxy = torch.ones((1,)).as_strided(x.gshape, [0] * x.ndim) - out_gshape = tuple(x_proxy.repeat(reps).shape) - - local_x = x.larray - - # allocate tiled DNDarray, at first tiled along split axis only - split_reps = [rep if i == x.split else 1 for i, rep in enumerate(reps)] - split_tiled_shape = tuple(x_proxy.repeat(split_reps).shape) - tiled = factories.empty(split_tiled_shape, dtype=x.dtype, split=x.split, comm=x.comm) - # collect slicing information from all processes. - slices_map = [] - for array in [x, tiled]: - counts, displs = array.counts_displs() - t_slices_starts = torch.tensor(displs, device=local_x.device) - t_slices_ends = t_slices_starts + torch.tensor(counts, device=local_x.device) - slices_map.append([t_slices_starts, t_slices_ends]) - - t_slices_x, t_slices_tiled = slices_map - - # keep track of repetitions: - # local_x_starts.shape, local_x_ends.shape changing from (size,) to (reps[split], size) - reps_indices = list(x.gshape[x.split] * rep for rep in (range(reps[x.split]))) - t_reps_indices = torch.tensor( - reps_indices, dtype=torch.int32, device=local_x.device - ).reshape(len(reps_indices), 1) - for i, t in enumerate(t_slices_x): - t = t.repeat((reps[x.split], 1)) - t += t_reps_indices - t_slices_x[i] = t - - # distribution logic on current rank: - distr_map = [] - slices_map = [] - for i in range(2): - if i == 0: - # send logic for x slices on rank - local_x_starts = t_slices_x[0][:, rank].reshape(reps[x.split], 1) - local_x_ends = t_slices_x[1][:, rank].reshape(reps[x.split], 1) - t_tiled_starts, t_tiled_ends = t_slices_tiled - else: - # recv logic for tiled slices on rank - local_x_starts, local_x_ends = t_slices_x - t_tiled_starts = t_slices_tiled[0][rank] - t_tiled_ends = t_slices_tiled[1][rank] - t_max_starts = torch.max(local_x_starts, t_tiled_starts) - t_min_ends = torch.min(local_x_ends, t_tiled_ends) - coords = torch.where(t_min_ends - t_max_starts > 0) - # remove repeat offset from slices if sending - if i == 0: - t_max_starts -= t_reps_indices - t_min_ends -= t_reps_indices - starts = t_max_starts[coords].unsqueeze_(0) - ends = t_min_ends[coords].unsqueeze_(0) - slices_map.append(torch.cat((starts, ends), dim=0)) - distr_map.append(coords) - - # bookkeeping in preparation for Alltoallv - send_map, recv_map = distr_map - send_rep, send_to_ranks = send_map - recv_rep, recv_from_ranks = recv_map - send_slices, recv_slices = slices_map - - # do not assume that `x` is balanced - _, displs = x.counts_displs() - offset_x = displs[rank] - # impose load-balance on output - offset_tiled, _, _ = tiled.comm.chunk(tiled.gshape, tiled.split) - t_tiled = tiled.larray - - active_send_counts = send_slices.clone() - active_send_counts[0] *= -1 - active_send_counts = active_send_counts.sum(0) - active_recv_counts = recv_slices.clone() - active_recv_counts[0] *= -1 - active_recv_counts = active_recv_counts.sum(0) - send_slices -= offset_x - recv_slices -= offset_tiled - recv_buf = t_tiled.clone() - # we need as many Alltoallv calls as repeats along the split axis - for rep in range(reps[x.split]): - # send_data, send_counts, send_displs on rank - all_send_counts = [0] * size - all_send_displs = [0] * size - send_this_rep = torch.where(send_rep == 
rep)[0].tolist() - dest_this_rep = send_to_ranks[send_this_rep].tolist() - for i, j in zip(send_this_rep, dest_this_rep): - all_send_counts[j] = active_send_counts[i].item() - all_send_displs[j] = send_slices[0][i].item() - local_send_slice = [slice(None)] * x.ndim - local_send_slice[x.split] = slice( - all_send_displs[0], all_send_displs[0] + sum(all_send_counts) - ) - send_buf = local_x[local_send_slice].clone() - - # recv_data, recv_counts, recv_displs on rank - all_recv_counts = [0] * size - all_recv_displs = [0] * size - recv_this_rep = torch.where(recv_rep == rep)[0].tolist() - orig_this_rep = recv_from_ranks[recv_this_rep].tolist() - for i, j in zip(recv_this_rep, orig_this_rep): - all_recv_counts[j] = active_recv_counts[i].item() - all_recv_displs[j] = recv_slices[0][i].item() - local_recv_slice = [slice(None)] * x.ndim - local_recv_slice[x.split] = slice( - all_recv_displs[0], all_recv_displs[0] + sum(all_recv_counts) - ) - x.comm.Alltoallv( - (send_buf, all_send_counts, all_send_displs), - (recv_buf, all_recv_counts, all_recv_displs), - ) - t_tiled[local_recv_slice] = recv_buf[local_recv_slice] + # repeats along the split axis, work along dim 0 + size = x.comm.Get_size() + rank = x.comm.Get_rank() + trans_axes = list(range(x.ndim)) + if x.split != 0: + trans_axes[0], trans_axes[x.split] = x.split, 0 + reps[0], reps[x.split] = reps[x.split], reps[0] + x = linalg.transpose(x, trans_axes) + x_proxy = torch.ones((1,)).as_strided(x.gshape, [0] * x.ndim) + out_gshape = tuple(x_proxy.repeat(reps).shape) - # finally tile along non-split axes if needed - reps[x.split] = 1 - tiled = DNDarray( - t_tiled.repeat(reps), - out_gshape, - dtype=x.dtype, - split=x.split, - device=x.device, - comm=x.comm, - balanced=True, + local_x = x.larray + + # allocate tiled DNDarray, at first tiled along split axis only + split_reps = [rep if i == x.split else 1 for i, rep in enumerate(reps)] + split_tiled_shape = tuple(x_proxy.repeat(split_reps).shape) + tiled = factories.empty(split_tiled_shape, dtype=x.dtype, split=x.split, comm=x.comm) + # collect slicing information from all processes. 
+ slices_map = [] + for array in [x, tiled]: + counts, displs = array.counts_displs() + t_slices_starts = torch.tensor(displs, device=local_x.device) + t_slices_ends = t_slices_starts + torch.tensor(counts, device=local_x.device) + slices_map.append([t_slices_starts, t_slices_ends]) + + t_slices_x, t_slices_tiled = slices_map + + # keep track of repetitions: + # local_x_starts.shape, local_x_ends.shape changing from (size,) to (reps[split], size) + reps_indices = list(x.gshape[x.split] * rep for rep in (range(reps[x.split]))) + t_reps_indices = torch.tensor(reps_indices, dtype=torch.int32, device=local_x.device).reshape( + len(reps_indices), 1 + ) + for i, t in enumerate(t_slices_x): + t = t.repeat((reps[x.split], 1)) + t += t_reps_indices + t_slices_x[i] = t + + # distribution logic on current rank: + distr_map = [] + slices_map = [] + for i in range(2): + if i == 0: + # send logic for x slices on rank + local_x_starts = t_slices_x[0][:, rank].reshape(reps[x.split], 1) + local_x_ends = t_slices_x[1][:, rank].reshape(reps[x.split], 1) + t_tiled_starts, t_tiled_ends = t_slices_tiled + else: + # recv logic for tiled slices on rank + local_x_starts, local_x_ends = t_slices_x + t_tiled_starts = t_slices_tiled[0][rank] + t_tiled_ends = t_slices_tiled[1][rank] + t_max_starts = torch.max(local_x_starts, t_tiled_starts) + t_min_ends = torch.min(local_x_ends, t_tiled_ends) + coords = torch.where(t_min_ends - t_max_starts > 0) + # remove repeat offset from slices if sending + if i == 0: + t_max_starts -= t_reps_indices + t_min_ends -= t_reps_indices + starts = t_max_starts[coords].unsqueeze_(0) + ends = t_min_ends[coords].unsqueeze_(0) + slices_map.append(torch.cat((starts, ends), dim=0)) + distr_map.append(coords) + + # bookkeeping in preparation for Alltoallv + send_map, recv_map = distr_map + send_rep, send_to_ranks = send_map + recv_rep, recv_from_ranks = recv_map + send_slices, recv_slices = slices_map + + # do not assume that `x` is balanced + _, displs = x.counts_displs() + offset_x = displs[rank] + # impose load-balance on output + offset_tiled, _, _ = tiled.comm.chunk(tiled.gshape, tiled.split) + t_tiled = tiled.larray + + active_send_counts = send_slices.clone() + active_send_counts[0] *= -1 + active_send_counts = active_send_counts.sum(0) + active_recv_counts = recv_slices.clone() + active_recv_counts[0] *= -1 + active_recv_counts = active_recv_counts.sum(0) + send_slices -= offset_x + recv_slices -= offset_tiled + recv_buf = t_tiled.clone() + # we need as many Alltoallv calls as repeats along the split axis + for rep in range(reps[x.split]): + # send_data, send_counts, send_displs on rank + all_send_counts = [0] * size + all_send_displs = [0] * size + send_this_rep = torch.where(send_rep == rep)[0].tolist() + dest_this_rep = send_to_ranks[send_this_rep].tolist() + for i, j in zip(send_this_rep, dest_this_rep): + all_send_counts[j] = active_send_counts[i].item() + all_send_displs[j] = send_slices[0][i].item() + local_send_slice = [slice(None)] * x.ndim + local_send_slice[x.split] = slice( + all_send_displs[0], all_send_displs[0] + sum(all_send_counts) + ) + send_buf = local_x[local_send_slice].clone() + + # recv_data, recv_counts, recv_displs on rank + all_recv_counts = [0] * size + all_recv_displs = [0] * size + recv_this_rep = torch.where(recv_rep == rep)[0].tolist() + orig_this_rep = recv_from_ranks[recv_this_rep].tolist() + for i, j in zip(recv_this_rep, orig_this_rep): + all_recv_counts[j] = active_recv_counts[i].item() + all_recv_displs[j] = recv_slices[0][i].item() + 
local_recv_slice = [slice(None)] * x.ndim + local_recv_slice[x.split] = slice( + all_recv_displs[0], all_recv_displs[0] + sum(all_recv_counts) ) - if trans_axes != list(range(x.ndim)): - # transpose back to original shape - x = linalg.transpose(x, trans_axes) - tiled = linalg.transpose(tiled, trans_axes) + x.comm.Alltoallv( + (send_buf, all_send_counts, all_send_displs), + (recv_buf, all_recv_counts, all_recv_displs), + ) + t_tiled[local_recv_slice] = recv_buf[local_recv_slice] + + # finally tile along non-split axes if needed + reps[x.split] = 1 + tiled = DNDarray( + t_tiled.repeat(reps), + out_gshape, + dtype=x.dtype, + split=x.split, + device=x.device, + comm=x.comm, + balanced=True, + ) + if trans_axes != list(range(x.ndim)): + # transpose back to original shape + x = linalg.transpose(x, trans_axes) + tiled = linalg.transpose(tiled, trans_axes) - return tiled + return tiled def topk(