Features/178 tile #673

Merged: 41 commits, Aug 20, 2021
Commits
06e5d8c
Implement tile() for local repeats
ClaudiaComito Sep 8, 2020
2952143
Fix output split in non_distributed tiling
ClaudiaComito Sep 8, 2020
f10af0c
Implement distributed tile()
ClaudiaComito Sep 11, 2020
ee7eb58
Implement tile() for local repeats
ClaudiaComito Sep 8, 2020
0cda9e6
Fix output split in non_distributed tiling
ClaudiaComito Sep 8, 2020
852eb2a
Implement distributed tile()
ClaudiaComito Sep 11, 2020
475b8c6
Implement distributed tiling along non-split axis if tiling along spl…
ClaudiaComito Sep 16, 2020
722bea6
Introduce sanitation module.
ClaudiaComito Sep 16, 2020
b817e8f
Merge branch 'features/178-tile' of github.com:helmholtz-analytics/he…
ClaudiaComito Sep 16, 2020
dc0083e
Implement tile case len(reps) < x.ndim
ClaudiaComito Sep 17, 2020
a703bec
Implement test_tile()
ClaudiaComito Sep 17, 2020
0d1d6ca
Check that reps is all integers, improve docs.
ClaudiaComito Sep 17, 2020
60cdd20
Increase test coverage
ClaudiaComito Sep 17, 2020
5ec8b80
Implement test_sanitation
ClaudiaComito Sep 17, 2020
ba93fd2
Expand sanitation docs
ClaudiaComito Sep 17, 2020
53320fd
Update changelog
ClaudiaComito Sep 17, 2020
4e4bbb4
Increase test coverage
ClaudiaComito Sep 17, 2020
205bf76
Set device for stand-alone torch tensor
ClaudiaComito Sep 17, 2020
5059726
Fix torch device
ClaudiaComito Sep 17, 2020
840bf09
Merge branch 'master' into features/178-tile
coquelin77 Sep 21, 2020
67802cd
black
coquelin77 Sep 21, 2020
f776d74
Merge branch 'master' into features/178-tile
ClaudiaComito Sep 22, 2020
7e4076b
Merge branch 'master' into features/178-tile
ClaudiaComito Jun 21, 2021
797b045
Merge branch 'master' into features/178-tile
ClaudiaComito Jul 8, 2021
39c8ec9
distributed ht.tile() overhaul via Alltoallv
ClaudiaComito Jul 9, 2021
dfdf9b3
Remove dead code
ClaudiaComito Jul 9, 2021
117db7c
Fix tests
ClaudiaComito Jul 9, 2021
6901f10
clean up
ClaudiaComito Jul 9, 2021
808b320
Cast out_gshape to tuple after torch multiplication
ClaudiaComito Jul 9, 2021
ef84199
Do not assume input is balanced
ClaudiaComito Jul 9, 2021
900c8cc
Test imbalanced input case
ClaudiaComito Jul 9, 2021
0c20652
operate on transposed data if original split is not 0
ClaudiaComito Jul 9, 2021
2531a7c
Test for non-zero split as well
ClaudiaComito Jul 9, 2021
9b8dfaf
Fix out_gshape (tuple of ints, not tuple of torch tensors)
ClaudiaComito Jul 9, 2021
91ad6ae
Merge branch 'master' into features/178-tile
ClaudiaComito Aug 10, 2021
5a4f3f7
Improve sanitation logic, fix bug occurring when `x.split != 0`
ClaudiaComito Aug 11, 2021
a692133
Expand tests for array-like reps
ClaudiaComito Aug 11, 2021
f412b05
Remove unnecessary `else:` block
ClaudiaComito Aug 11, 2021
35564bc
Merge branch 'master' into features/178-tile
ClaudiaComito Aug 20, 2021
05a2b70
Merge branch 'master' into features/178-tile
ClaudiaComito Aug 20, 2021
d4f450e
Merge branch 'master' into features/178-tile
coquelin77 Aug 20, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -224,6 +224,8 @@ Example on 2 processes:
- [#664](https://github.com/helmholtz-analytics/heat/pull/664) New feature / enhancement: distributed `random.random_sample`, `random.random`, `random.sample`, `random.ranf`, `random.random_integer`
- [#666](https://github.com/helmholtz-analytics/heat/pull/666) New feature: distributed prepend/append for `diff()`.
- [#667](https://github.com/helmholtz-analytics/heat/pull/667) Enhancement `reshape`: rename axis parameter
- [#678](https://github.com/helmholtz-analytics/heat/pull/678) New feature: distributed `tile`
- [#670](https://github.com/helmholtz-analytics/heat/pull/670) New Feature: `bincount()`
- [#674](https://github.com/helmholtz-analytics/heat/pull/674) New feature: `repeat`
- [#670](https://github.com/helmholtz-analytics/heat/pull/670) New Feature: distributed `bincount()`
- [#672](https://github.com/helmholtz-analytics/heat/pull/672) Bug / Enhancement: Remove `MPIRequest.wait()`, rewrite calls with capital letters. lower case `wait()` now falls back to the `mpi4py` function
257 changes: 257 additions & 0 deletions heat/core/manipulations.py
@@ -52,6 +52,7 @@
"squeeze",
"stack",
"swapaxes",
"tile",
"topk",
"unique",
"vsplit",
@@ -3596,6 +3597,262 @@ def vstack(arrays: Sequence[DNDarray, ...]) -> DNDarray:
return concatenate(arrays, axis=0)


def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray:
"""
Construct a new DNDarray by repeating 'x' the number of times given by 'reps'.

If 'reps' has length 'd', the result will have 'max(d, x.ndim)' dimensions:

- if 'x.ndim < d', 'x' is promoted to be d-dimensional by prepending new axes.
So a shape (3,) array is promoted to (1, 3) for 2-D replication, or shape (1, 1, 3)
for 3-D replication (if this is not the desired behavior, promote 'x' to d-dimensions
manually before calling this function);

- if 'x.ndim > d', 'reps' will replicate the last 'd' dimensions of 'x', i.e., if
'x.shape' is (2, 3, 4, 5), a 'reps' of (2, 2) will be expanded to (1, 1, 2, 2).

Parameters
----------
x : DNDarray
Input

reps : Sequence[int, ...]
Repetitions

Returns
-------
tiled : DNDarray
Split semantics: if `x` is distributed, the tiled data will be distributed along the
same dimension. Note that nominally `tiled.split != x.split` in the case where
`len(reps) > x.ndim`. See example below.

Examples
--------
>>> x = ht.arange(12).reshape((4,3)).resplit_(0)
>>> x
DNDarray([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]], dtype=ht.int32, device=cpu:0, split=0)
>>> reps = (1, 2, 2)
>>> tiled = ht.tile(x, reps)
>>> tiled
DNDarray([[[ 0, 1, 2, 0, 1, 2],
[ 3, 4, 5, 3, 4, 5],
[ 6, 7, 8, 6, 7, 8],
[ 9, 10, 11, 9, 10, 11],
[ 0, 1, 2, 0, 1, 2],
[ 3, 4, 5, 3, 4, 5],
[ 6, 7, 8, 6, 7, 8],
[ 9, 10, 11, 9, 10, 11]]], dtype=ht.int32, device=cpu:0, split=1)
"""
# x can be DNDarray or scalar
try:
_ = x.larray
except AttributeError:
try:
_ = x.shape
raise TypeError("Input can be a DNDarray or a scalar, is {}".format(type(x)))
except AttributeError:
x = factories.array(x).reshape(1)

x_proxy = torch.ones((1,)).as_strided(x.gshape, [0] * x.ndim)
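# x_proxy mimics x's global shape without allocating memory (stride-0 view of a single
# element); calling torch.repeat on it validates `reps` and yields global output shapes
# without touching or communicating any actual data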

# torch-proof args/kwargs:
# torch `reps`: int or sequence of ints; numpy `reps`: can be array-like
try:
_ = x_proxy.repeat(reps)
except TypeError:
# `reps` is array-like or contains non-int elements
try:
reps = resplit(reps, None).tolist()
except AttributeError:
try:
reps = reps.tolist()
except AttributeError:
try:
_ = x_proxy.repeat(reps)
except TypeError:
raise TypeError(
"reps must be a sequence of ints, got {}".format(
list(type(i) for i in reps)
)
)
except RuntimeError:
pass
except RuntimeError:
pass

try:
reps = list(reps)
except TypeError:
# scalar to list
reps = [reps]

# torch reps vs. numpy reps: dimensions
if len(reps) != x.ndim:
added_dims = abs(len(reps) - x.ndim)
if len(reps) > x.ndim:
new_shape = added_dims * (1,) + x.gshape
new_split = None if x.split is None else x.split + added_dims
x = x.reshape(new_shape, new_split=new_split)
else:
reps = added_dims * [1] + reps

out_gshape = tuple(x_proxy.repeat(reps).shape)

if not x.is_distributed() or reps[x.split] == 1:
# no repeats along the split axis: local operation
t_tiled = x.larray.repeat(reps)
out_gshape = tuple(x_proxy.repeat(reps).shape)
return DNDarray(
t_tiled,
out_gshape,
dtype=x.dtype,
split=x.split,
device=x.device,
comm=x.comm,
balanced=x.balanced,
)
# repeats along the split axis, work along dim 0
size = x.comm.Get_size()
rank = x.comm.Get_rank()
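# if the split axis is not dim 0, operate on a transposed view so the split axis becomes
# dim 0, swapping the matching entries of `reps`; the result is transposed back at the end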
trans_axes = list(range(x.ndim))
if x.split != 0:
trans_axes[0], trans_axes[x.split] = x.split, 0
reps[0], reps[x.split] = reps[x.split], reps[0]
x = linalg.transpose(x, trans_axes)
x_proxy = torch.ones((1,)).as_strided(x.gshape, [0] * x.ndim)
out_gshape = tuple(x_proxy.repeat(reps).shape)

local_x = x.larray

# allocate tiled DNDarray, at first tiled along split axis only
split_reps = [rep if i == x.split else 1 for i, rep in enumerate(reps)]
split_tiled_shape = tuple(x_proxy.repeat(split_reps).shape)
tiled = factories.empty(split_tiled_shape, dtype=x.dtype, split=x.split, comm=x.comm)
# collect slicing information from all processes.
slices_map = []
for array in [x, tiled]:
counts, displs = array.counts_displs()
t_slices_starts = torch.tensor(displs, device=local_x.device)
t_slices_ends = t_slices_starts + torch.tensor(counts, device=local_x.device)
slices_map.append([t_slices_starts, t_slices_ends])

t_slices_x, t_slices_tiled = slices_map

# keep track of repetitions:
# local_x_starts.shape, local_x_ends.shape changing from (size,) to (reps[split], size)
reps_indices = list(x.gshape[x.split] * rep for rep in (range(reps[x.split])))
t_reps_indices = torch.tensor(reps_indices, dtype=torch.int32, device=local_x.device).reshape(
len(reps_indices), 1
)
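# the loop below offsets repetition r's copy of the slice boundaries by r * x.gshape[x.split],
# i.e. repetition r maps onto the r-th block of the split-tiled output along the split axis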
for i, t in enumerate(t_slices_x):
t = t.repeat((reps[x.split], 1))
t += t_reps_indices
t_slices_x[i] = t

# distribution logic on current rank:
distr_map = []
slices_map = []
for i in range(2):
if i == 0:
# send logic for x slices on rank
local_x_starts = t_slices_x[0][:, rank].reshape(reps[x.split], 1)
local_x_ends = t_slices_x[1][:, rank].reshape(reps[x.split], 1)
t_tiled_starts, t_tiled_ends = t_slices_tiled
else:
# recv logic for tiled slices on rank
local_x_starts, local_x_ends = t_slices_x
t_tiled_starts = t_slices_tiled[0][rank]
t_tiled_ends = t_slices_tiled[1][rank]
t_max_starts = torch.max(local_x_starts, t_tiled_starts)
t_min_ends = torch.min(local_x_ends, t_tiled_ends)
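# an exchange is needed wherever the two sets of intervals overlap, i.e. min(end) - max(start) > 0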
coords = torch.where(t_min_ends - t_max_starts > 0)
# remove repeat offset from slices if sending
if i == 0:
t_max_starts -= t_reps_indices
t_min_ends -= t_reps_indices
starts = t_max_starts[coords].unsqueeze_(0)
ends = t_min_ends[coords].unsqueeze_(0)
slices_map.append(torch.cat((starts, ends), dim=0))
distr_map.append(coords)

# bookkeeping in preparation for Alltoallv
send_map, recv_map = distr_map
send_rep, send_to_ranks = send_map
recv_rep, recv_from_ranks = recv_map
send_slices, recv_slices = slices_map

# do not assume that `x` is balanced
_, displs = x.counts_displs()
offset_x = displs[rank]
# impose load-balance on output
offset_tiled, _, _ = tiled.comm.chunk(tiled.gshape, tiled.split)
t_tiled = tiled.larray

active_send_counts = send_slices.clone()
active_send_counts[0] *= -1
active_send_counts = active_send_counts.sum(0)
active_recv_counts = recv_slices.clone()
active_recv_counts[0] *= -1
active_recv_counts = active_recv_counts.sum(0)
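# the counts computed above equal ends - starts: negating the row of starts and summing
# the two rows of each (2, n) slice tensor yields the element count per exchange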
send_slices -= offset_x
recv_slices -= offset_tiled
recv_buf = t_tiled.clone()
# we need as many Alltoallv calls as repeats along the split axis
for rep in range(reps[x.split]):
# send_data, send_counts, send_displs on rank
all_send_counts = [0] * size
all_send_displs = [0] * size
send_this_rep = torch.where(send_rep == rep)[0].tolist()
dest_this_rep = send_to_ranks[send_this_rep].tolist()
for i, j in zip(send_this_rep, dest_this_rep):
all_send_counts[j] = active_send_counts[i].item()
all_send_displs[j] = send_slices[0][i].item()
local_send_slice = [slice(None)] * x.ndim
local_send_slice[x.split] = slice(
all_send_displs[0], all_send_displs[0] + sum(all_send_counts)
)
send_buf = local_x[local_send_slice].clone()

# recv_data, recv_counts, recv_displs on rank
all_recv_counts = [0] * size
all_recv_displs = [0] * size
recv_this_rep = torch.where(recv_rep == rep)[0].tolist()
orig_this_rep = recv_from_ranks[recv_this_rep].tolist()
for i, j in zip(recv_this_rep, orig_this_rep):
all_recv_counts[j] = active_recv_counts[i].item()
all_recv_displs[j] = recv_slices[0][i].item()
local_recv_slice = [slice(None)] * x.ndim
local_recv_slice[x.split] = slice(
all_recv_displs[0], all_recv_displs[0] + sum(all_recv_counts)
)
x.comm.Alltoallv(
(send_buf, all_send_counts, all_send_displs),
(recv_buf, all_recv_counts, all_recv_displs),
)
t_tiled[local_recv_slice] = recv_buf[local_recv_slice]

# finally tile along non-split axes if needed
reps[x.split] = 1
tiled = DNDarray(
t_tiled.repeat(reps),
out_gshape,
dtype=x.dtype,
split=x.split,
device=x.device,
comm=x.comm,
balanced=True,
)
if trans_axes != list(range(x.ndim)):
# transpose back to original shape
x = linalg.transpose(x, trans_axes)
tiled = linalg.transpose(tiled, trans_axes)

return tiled


def topk(
a: DNDarray,
k: int,
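A minimal usage sketch of the dimension-promotion rules described in the `tile` docstring above (illustration only, not part of the diff); the shapes noted in the comments follow from those rules and mirror `numpy.tile`:

import heat as ht

x = ht.arange(6).reshape((2, 3))               # x.ndim == 2

# len(reps) > x.ndim: x is promoted by prepending axes, (2, 3) -> (1, 2, 3)
print(ht.tile(x, (2, 1, 1)).shape)             # expected (2, 2, 3)

# len(reps) < x.ndim: reps is padded with leading 1s, (2,) -> (1, 2)
print(ht.tile(x, (2,)).shape)                  # expected (2, 6)

# distributed input: repeats along the split axis go through the Alltoallv path
x = ht.arange(12).reshape((4, 3)).resplit_(0)
print(ht.tile(x, (3, 1)).shape)                # expected (12, 3), still distributed along dim 0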
77 changes: 77 additions & 0 deletions heat/core/tests/test_manipulations.py
@@ -3267,6 +3267,83 @@ def test_swapaxes(self):
with self.assertRaises(TypeError):
ht.swapaxes(x, 4.9, "abc")

def test_tile(self):
# test local tile, tuple reps
x = ht.arange(12).reshape((4, 3))
reps = (2, 1)
ht_tiled = ht.tile(x, reps)
np_tiled = np.tile(x.numpy(), reps)
self.assertTrue((np_tiled == ht_tiled.numpy()).all())
self.assertTrue(ht_tiled.dtype is x.dtype)

# test scalar x
x = ht.array(9.0)
reps = (2, 1)
ht_tiled = ht.tile(x, reps)
np_tiled = np.tile(x.numpy(), reps)
self.assertTrue((np_tiled == ht_tiled.numpy()).all())
self.assertTrue(ht_tiled.dtype is x.dtype)

# test distributed tile along split axis, reps passed as a DNDarray
# len(reps) > x.ndim
split = 1
x = ht.random.randn(4, 3, split=split)
reps = ht.random.randint(2, 10, size=(4,))
tiled_along_split = ht.tile(x, reps)
np_tiled_along_split = np.tile(x.numpy(), reps.tolist())
self.assertTrue((tiled_along_split.numpy() == np_tiled_along_split).all())
self.assertTrue(tiled_along_split.dtype is x.dtype)

# test distributed tile with original split 0, reps passed as a NumPy array
# len(reps) > x.ndim
split = 0
x = ht.random.randn(4, 3, split=split)
reps = np.random.randint(2, 10, size=(4,))
tiled_along_split = ht.tile(x, reps)
np_tiled_along_split = np.tile(x.numpy(), reps)
self.assertTrue((tiled_along_split.numpy() == np_tiled_along_split).all())
self.assertTrue(tiled_along_split.dtype is x.dtype)

# test distributed tile() on imbalanced DNDarray
x = ht.random.randn(100, split=0)
x = x[ht.where(x > 0)]
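# boolean indexing keeps a different number of elements on each process, leaving x unbalanced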
reps = 5
imbalanced_tiled_along_split = ht.tile(x, reps)
np_imbalanced_tiled_along_split = np.tile(x.numpy(), reps)
self.assertTrue(
(imbalanced_tiled_along_split.numpy() == np_imbalanced_tiled_along_split).all()
)
self.assertTrue(imbalanced_tiled_along_split.dtype is x.dtype)
self.assertTrue(imbalanced_tiled_along_split.is_balanced(force_check=True))

# test tile along non-split axis
# len(reps) < x.ndim
split = 1
x = ht.random.randn(4, 5, 3, 10, dtype=ht.float64, split=split)
reps = (2, 2)
tiled_along_non_split = ht.tile(x, reps)
np_tiled_along_non_split = np.tile(x.numpy(), reps)
self.assertTrue((tiled_along_non_split.numpy() == np_tiled_along_non_split).all())
self.assertTrue(tiled_along_non_split.dtype is x.dtype)

# test tile along split axis
# len(reps) = x.ndim
split = 1
x = ht.random.randn(3, 3, dtype=ht.float64, split=split)
reps = (2, 3)
tiled_along_split = ht.tile(x, reps)
np_tiled_along_split = np.tile(x.numpy(), reps)
self.assertTrue((tiled_along_split.numpy() == np_tiled_along_split).all())
self.assertTrue(tiled_along_split.dtype is x.dtype)

# test exceptions
float_reps = (1, 2, 2, 1.5)
with self.assertRaises(TypeError):
tiled_along_split = ht.tile(x, float_reps)
arraylike_float_reps = torch.tensor(float_reps)
with self.assertRaises(TypeError):
tiled_along_split = ht.tile(x, arraylike_float_reps)

def test_topk(self):
size = ht.MPI_WORLD.size
if size == 1: