Skip to content

Commit

Permalink
Resolves issue #422: added comm checks for concatenate, modified documentation to state exceptions, changed exception types to consistent Python usage
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus-Goetz committed Apr 3, 2020
1 parent 2cec7a2 commit dc8dccc
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 12 deletions.
58 changes: 47 additions & 11 deletions heat/core/manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@ def concatenate(arrays, axis=0):
res: DNDarray
The concatenated DNDarray
Raises
------
RuntimeError
If the concatted DNDarray meta information, e.g. split or comm, does not match.
TypeError
If the passed parameters are not of correct type (see documentation above).
ValueError
If the number of passed arrays is less than two or their shapes do not match.
Examples
--------
>>> x = ht.zeros((3, 5), split=None)
Expand Down Expand Up @@ -93,62 +102,80 @@ def concatenate(arrays, axis=0):
"""
if not isinstance(arrays, (tuple, list)):
raise TypeError("arrays must be a list or a tuple")

# a single array cannot be concatenated
if len(arrays) < 2:
raise ValueError("concatenate requires 2 arrays")
# concatenate multiple arrays
elif len(arrays) > 2:
res = concatenate((arrays[0], arrays[1]), axis=axis)
for a in range(2, len(arrays)):
res = concatenate((res, arrays[a]), axis=axis)
return res

# unpack the arrays
arr0, arr1 = arrays[0], arrays[1]

# input sanitation
if not isinstance(arr0, dndarray.DNDarray) or not isinstance(arr1, dndarray.DNDarray):
raise TypeError("Both arrays must be DNDarrays")

if not isinstance(axis, int):
raise TypeError("axis must be an integer, currently: {}".format(type(axis)))

axis = stride_tricks.sanitize_axis(arr0.gshape, axis)

if arr0.numdims != arr1.numdims:
raise RuntimeError("DNDarrays must have the same number of dimensions")
raise ValueError("DNDarrays must have the same number of dimensions")

if not all([arr0.gshape[i] == arr1.gshape[i] for i in range(len(arr0.gshape)) if i != axis]):
raise ValueError(
"Arrays cannot be concatenated, gshapes must be the same in every axis "
"Arrays cannot be concatenated, shapes must be the same in every axis "
"except the selected axis: {}, {}".format(arr0.gshape, arr1.gshape)
)

s0, s1 = arr0.split, arr1.split
# different communicators may not be concatenated
if arr0.comm != arr1.comm:
raise RuntimeError("Communicators of passed arrays mismatch.")

# identify common data type
out_dtype = types.promote_types(arr0.dtype, arr1.dtype)
if arr0.dtype != out_dtype:
arr0 = out_dtype(arr0, device=arr0.device)
if arr1.dtype != out_dtype:
arr1 = out_dtype(arr1, device=arr1.device)

s0, s1 = arr0.split, arr1.split
# no splits, local concat
if s0 is None and s1 is None:
return factories.array(
torch.cat((arr0._DNDarray__array, arr1._DNDarray__array), dim=axis), device=arr0.device
torch.cat((arr0._DNDarray__array, arr1._DNDarray__array), dim=axis),
device=arr0.device,
comm=arr0.comm,
)

# non-matching splits when both arrays are split
elif s0 != s1 and all([s is not None for s in [s0, s1]]):
raise RuntimeError(
"DNDarrays given have differing numerical splits, arr0 {} arr1 {}".format(s0, s1)
"DNDarrays given have differing split axes, arr0 {} arr1 {}".format(s0, s1)
)

# unsplit and split array
elif (s0 is None and s1 != axis) or (s1 is None and s0 != axis):
out_shape = tuple(
arr1.gshape[x] if x != axis else arr0.gshape[x] + arr1.gshape[x]
for x in range(len(arr1.gshape))
)
out = factories.empty(out_shape, split=s1 if s1 is not None else s0, device=arr1.device)
out = factories.empty(
out_shape, split=s1 if s1 is not None else s0, device=arr1.device, comm=arr0.comm
)

_, _, arr0_slice = arr1.comm.chunk(arr0.shape, arr1.split)
_, _, arr1_slice = arr0.comm.chunk(arr1.shape, arr0.split)
out._DNDarray__array = torch.cat(
(arr0._DNDarray__array[arr0_slice], arr1._DNDarray__array[arr1_slice]), dim=axis
)
out._DNDarray__comm = arr0.comm

return out

elif s0 == s1 or any([s is None for s in [s0, s1]]):
Expand All @@ -163,7 +190,9 @@ def concatenate(arrays, axis=0):
out._DNDarray__array = torch.cat(
(arr0._DNDarray__array, arr1._DNDarray__array), dim=axis
)
out._DNDarray__comm = arr0.comm
return out

else:
arr0 = arr0.copy()
arr1 = arr1.copy()
Expand Down Expand Up @@ -281,8 +310,9 @@ def concatenate(arrays, axis=0):
arb_slice = [None] * len(arr1.shape)
for c in range(len(chunk_map)):
arb_slice[axis] = c
# the chunk map is adjusted by subtracting what data is already in the correct place (the data from arr1 is already correctly placed)
# i.e. the chunk map shows how much data is still needed on each process, the local
# the chunk map is adjusted by subtracting what data is already in the correct place (the data from
# arr1 is already correctly placed) i.e. the chunk map shows how much data is still needed on each
# process, the local
chunk_map[arb_slice] -= lshape_map[tuple([1] + arb_slice)]

# after adjusting arr1 need to now select the target data in arr0 on each node with a local slice
Expand Down Expand Up @@ -328,12 +358,18 @@ def concatenate(arrays, axis=0):
if len(arr1.lshape) < len(arr0.lshape):
arr1._DNDarray__array.unsqueeze_(axis)

# now that the data is in the proper shape, need to concatenate them on the nodes where they both exist for the others, just set them equal
# now that the data is in the proper shape, need to concatenate them on the nodes where they both exist for
# the others, just set them equal
out = factories.empty(
out_shape, split=s0 if s0 is not None else s1, dtype=out_dtype, device=arr0.device
out_shape,
split=s0 if s0 is not None else s1,
dtype=out_dtype,
device=arr0.device,
comm=arr0.comm,
)
res = torch.cat((arr0._DNDarray__array, arr1._DNDarray__array), dim=axis)
out._DNDarray__array = res

return out


Expand Down
6 changes: 5 additions & 1 deletion heat/core/tests/test_manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,12 @@ def test_concatenate(self):
ht.concatenate((x))
with self.assertRaises(TypeError):
ht.concatenate((x, x), axis=x)
with self.assertRaises(RuntimeError):
with self.assertRaises(ValueError):
ht.concatenate((x, ht.zeros((2, 2), device=ht_device)), axis=0)
with self.assertRaises(RuntimeError):
a = ht.zeros((10,), comm=ht.communication.MPI_WORLD)
b = ht.zeros((10,), comm=ht.communication.MPI_SELF)
ht.concatenate([a, b])
with self.assertRaises(ValueError):
ht.concatenate(
(ht.zeros((12, 12), device=ht_device), ht.zeros((2, 2), device=ht_device)), axis=0
Expand Down

0 comments on commit dc8dccc

Please sign in to comment.