
Commit

Merge pull request #1075 from AsRaNi1/Communicator-not-properly-initialized-when-creating-new-DNDarrays-in-some-routines/1074-my-bug-fix

Fixed initialization of the DNDarray communicator in some routines
mrfh92 authored May 25, 2023
2 parents 468999e + a373be8 commit 5fd3af5
Showing 5 changed files with 43 additions and 24 deletions.
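
The change applies one pattern throughout: wherever a routine wraps a process-local torch tensor back into a DNDarray via the factories module, it now forwards the communicator of the input array instead of letting the factory fall back to the default (world) communicator. Below is a minimal sketch of that pattern, assuming the usual "import heat as ht" convention; scale_locally is a hypothetical helper used only for illustration, not part of this commit or the Heat API.

    import heat as ht


    def scale_locally(x: ht.DNDarray, factor: float) -> ht.DNDarray:
        # Work on the process-local torch tensor ...
        scaled = factor * x.larray
        # ... then wrap the result back into a DNDarray. Before this commit such
        # wrapper calls often omitted comm, so the new array silently fell back to
        # the default (world) communicator; forwarding comm=x.comm keeps the result
        # on the same communicator as its input.
        return ht.array(scaled, dtype=x.dtype, is_split=x.split, device=x.device, comm=x.comm)

The hunks below apply exactly this kind of change to factories.array, factories.zeros, and factories.ones calls in _operations.py, linalg/basics.py, linalg/qr.py, manipulations.py, and statistics.py.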
6 changes: 5 additions & 1 deletion heat/core/_operations.py
@@ -296,7 +296,11 @@ def __cum_op(
         return out
 
     return factories.array(
-        cumop, dtype=x.dtype if dtype is None else dtype, is_split=x.split, device=x.device
+        cumop,
+        dtype=x.dtype if dtype is None else dtype,
+        is_split=x.split,
+        device=x.device,
+        comm=x.comm,
     )


39 changes: 23 additions & 16 deletions heat/core/linalg/basics.py
@@ -116,7 +116,7 @@ def cross(
         a_2d = True
         shape = tuple(1 if i == axisa else j for i, j in enumerate(a.shape))
         a = manipulations.concatenate(
-            [a, factories.zeros(shape, dtype=a.dtype, device=a.device)], axis=axisa
+            [a, factories.zeros(shape, dtype=a.dtype, device=a.device, comm=a.comm)], axis=axisa
         )
     if b.shape[axisb] == 2:
         b_2d = True
@@ -207,7 +207,7 @@ def det(a: DNDarray) -> DNDarray:
 
     acopy = a.copy()
     acopy = manipulations.reshape(acopy, (-1, m, m), new_split=a.split - a.ndim + 3)
-    adet = factories.ones(acopy.shape[0], dtype=a.dtype, device=a.device)
+    adet = factories.ones(acopy.shape[0], dtype=a.dtype, device=a.device, comm=a.comm)
 
     for k in range(adet.shape[0]):
         m = 0
@@ -475,6 +475,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
[1/1] tensor([[3., 1., 1., 1., 1., 1., 1.],
[4., 1., 1., 1., 1., 1., 1.]])
>>> linalg.matmul(a, b).larray
[0/1] tensor([[18., 8., 9., 10.],
[14., 6., 7., 8.],
[18., 7., 8., 9.],
@@ -520,7 +521,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
     if a.split is None and b.split is None:  # matmul from torch
         if len(a.gshape) < 2 or len(b.gshape) < 2 or not allow_resplit:
             # if either of A or B is a vector
-            ret = factories.array(torch.matmul(a.larray, b.larray), device=a.device)
+            ret = factories.array(torch.matmul(a.larray, b.larray), device=a.device, comm=a.comm)
             if gpu_int_flag:
                 ret = og_type(ret, device=a.device)
             return ret
@@ -529,7 +530,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
     slice_0 = a.comm.chunk(a.shape, a.split)[2][0]
     hold = a.larray @ b.larray
 
-    c = factories.zeros((a.gshape[-2], b.gshape[1]), dtype=c_type, device=a.device)
+    c = factories.zeros((a.gshape[-2], b.gshape[1]), dtype=c_type, device=a.device, comm=a.comm)
     c.larray[slice_0.start : slice_0.stop, :] += hold
     c.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
     if gpu_int_flag:
@@ -544,7 +545,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
     b.resplit_(0)
     res = a.larray @ b.larray
     a.comm.Allreduce(MPI.IN_PLACE, res, MPI.SUM)
-    ret = factories.array(res, split=None, device=a.device)
+    ret = factories.array(res, split=None, device=a.device, comm=a.comm)
     if gpu_int_flag:
         ret = og_type(ret, device=a.device)
     return ret
@@ -567,7 +568,9 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
     ) and not vector_flag:
         split = a.split if a.split is not None else b.split
         split = split if not vector_flag else 0
-        c = factories.zeros((a.gshape[-2], b.gshape[1]), split=split, dtype=c_type, device=a.device)
+        c = factories.zeros(
+            (a.gshape[-2], b.gshape[1]), split=split, dtype=c_type, device=a.device, comm=a.comm
+        )
         c.larray += a.larray @ b.larray
 
         ret = c if not vector_flag else c.squeeze()
@@ -582,7 +585,9 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
     c += a.larray @ b.larray[a_idx[1].start : a_idx[1].start + a.lshape[-1], :]
     a.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
     c = c if not vector_flag else c.squeeze()
-    ret = factories.array(c, split=a.split if b.gshape[1] > 1 else 0, device=a.device)
+    ret = factories.array(
+        c, split=a.split if b.gshape[1] > 1 else 0, device=a.device, comm=a.comm
+    )
     if gpu_int_flag:
         ret = og_type(ret, device=a.device)
     return ret
@@ -593,7 +598,9 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
     c += a.larray[:, b_idx[0].start : b_idx[0].start + b.lshape[0]] @ b.larray
     b.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
     c = c if not vector_flag else c.squeeze()
-    ret = factories.array(c, split=b.split if a.gshape[-2] > 1 else 0, device=a.device)
+    ret = factories.array(
+        c, split=b.split if a.gshape[-2] > 1 else 0, device=a.device, comm=a.comm
+    )
     if gpu_int_flag:
         ret = og_type(ret, device=a.device)
     return ret
@@ -608,7 +615,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
     c = c if not vector_flag else c.squeeze()
     split = a.split if b.gshape[1] > 1 else 0
     split = split if not vector_flag else 0
-    ret = factories.array(c, split=split, device=a.device)
+    ret = factories.array(c, split=split, device=a.device, comm=a.comm)
     if gpu_int_flag:
         ret = og_type(ret, device=a.device)
     return ret
@@ -619,7 +626,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
     c = c if not vector_flag else c.squeeze()
     split = b.split if a.gshape[1] > 1 else 0
     split = split if not vector_flag else 0
-    ret = factories.array(c, is_split=split, device=a.device)
+    ret = factories.array(c, is_split=split, device=a.device, comm=a.comm)
     if gpu_int_flag:
         ret = og_type(ret, device=a.device)
     return ret
@@ -695,10 +702,10 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
 
     # for the communication scheme, the output array needs to be created
     c_shape = (a.gshape[-2], b.gshape[1])
-    c = factories.zeros(c_shape, split=a.split, dtype=c_type, device=a.device)
+    c = factories.zeros(c_shape, split=a.split, dtype=c_type, device=a.device, comm=a.comm)
 
     # get the index map for c
-    c_index_map = factories.zeros((c.comm.size, 2, 2), device=a.device)
+    c_index_map = factories.zeros((c.comm.size, 2, 2), device=a.device, comm=a.comm)
     c_idx = c.comm.chunk(c.shape, c.split)[2]
     c_index_map[c.comm.rank, 0, :] = (c_idx[0].start, c_idx[0].stop)
     c_index_map[c.comm.rank, 1, :] = (c_idx[1].start, c_idx[1].stop)
@@ -919,7 +926,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
     if c_loc.nelement() == 1:
         c_loc = torch.tensor(c_loc, device=tdev)
 
-    c = factories.array(c_loc, is_split=0, device=a.device)
+    c = factories.array(c_loc, is_split=0, device=a.device, comm=a.comm)
     if gpu_int_flag:
         c = og_type(c, device=a.device)
     return c
@@ -1023,7 +1030,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
     c.larray[:, : b_node_rem_s1.shape[1]] += a_rem @ b_node_rem_s1
     del a_lp_data[pr]
     if vector_flag:
-        c = factories.array(c.larray.squeeze(), is_split=0, device=a.device)
+        c = factories.array(c.larray.squeeze(), is_split=0, device=a.device, comm=a.comm)
     if gpu_int_flag:
         c = og_type(c, device=a.device)
     return c
@@ -1066,7 +1073,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
     c.larray[: sp0 - st0, st1:sp1] += a.larray @ b_lp_data[pr]
     del b_lp_data[pr]
     if vector_flag:
-        c = factories.array(c.larray.squeeze(), is_split=0, device=a.device)
+        c = factories.array(c.larray.squeeze(), is_split=0, device=a.device, comm=a.comm)
     if gpu_int_flag:
         c = og_type(c, device=a.device)
 
@@ -1090,7 +1097,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
     if vector_flag:
         split = 0
         res = res.squeeze()
-    c = factories.array(res, split=split, device=a.device)
+    c = factories.array(res, split=split, device=a.device, comm=a.comm)
     if gpu_int_flag:
         c = og_type(c, device=a.device)
     return c
4 changes: 2 additions & 2 deletions heat/core/linalg/qr.py
@@ -99,8 +99,8 @@ def qr(
         except AttributeError:
             q, r = a.larray.qr(some=False)
 
-        q = factories.array(q, device=a.device)
-        r = factories.array(r, device=a.device)
+        q = factories.array(q, device=a.device, comm=a.comm)
+        r = factories.array(r, device=a.device, comm=a.comm)
         ret = QR(q if calc_q else None, r)
         return ret
     # =============================== Prep work ====================================================
15 changes: 10 additions & 5 deletions heat/core/manipulations.py
@@ -3085,10 +3085,13 @@ def unique(
         )
         if isinstance(torch_output, tuple):
             heat_output = tuple(
-                factories.array(i, dtype=a.dtype, split=None, device=a.device) for i in torch_output
+                factories.array(i, dtype=a.dtype, split=None, device=a.device, comm=a.comm)
+                for i in torch_output
             )
         else:
-            heat_output = factories.array(torch_output, dtype=a.dtype, split=None, device=a.device)
+            heat_output = factories.array(
+                torch_output, dtype=a.dtype, split=None, device=a.device, comm=a.comm
+            )
         return heat_output
 
     local_data = a.larray
@@ -3366,7 +3369,7 @@ def resplit(arr: DNDarray, axis: int = None) -> DNDarray:
     if axis == arr.split:
         return arr.copy()
     if not arr.is_distributed():
-        return factories.array(arr.larray, split=axis, device=arr.device, copy=True)
+        return factories.array(arr.larray, split=axis, device=arr.device, comm=arr.comm, copy=True)
 
     if axis is None:
         # new_arr = arr.copy()
@@ -3375,7 +3378,9 @@ def resplit(arr: DNDarray, axis: int = None) -> DNDarray:
         )
         counts, displs = arr.counts_displs()
         arr.comm.Allgatherv(arr.larray, (gathered, counts, displs), recv_axis=arr.split)
-        new_arr = factories.array(gathered, is_split=axis, device=arr.device, dtype=arr.dtype)
+        new_arr = factories.array(
+            gathered, is_split=axis, device=arr.device, comm=arr.comm, dtype=arr.dtype
+        )
         return new_arr
 
     arr_tiles = tiling.SplitTiles(arr)
@@ -3954,7 +3959,7 @@ def local_topk(*args, **kwargs):
         gres, dtype=a.dtype, device=a.device, split=split, is_split=is_split
     )
     final_indices = factories.array(
-        gindices, dtype=types.int64, device=a.device, split=split, is_split=is_split
+        gindices, dtype=types.int64, device=a.device, comm=a.comm, split=split, is_split=is_split
     )
 
     if out is not None:
3 changes: 3 additions & 0 deletions heat/core/statistics.py
@@ -306,6 +306,7 @@ def average(
         torch.broadcast_tensors(cumwgt.larray, result.larray)[0],
         is_split=result.split,
         device=result.device,
+        comm=result.comm,
         copy=False,
     )
     return (result, cumwgt)
@@ -1284,6 +1285,7 @@ def __moment_w_axis(
             is_split=x.split if axis > x.split else x.split - 1,
             dtype=x.dtype,
             device=x.device,
+            comm=x.comm,
             copy=False,
         )
     elif not isinstance(axis, (list, tuple, torch.Tensor)):
@@ -1328,6 +1330,7 @@ def __moment_w_axis(
         function(x.larray, **kwargs),
         is_split=x.split if x.split < len(output_shape) else len(output_shape) - 1,
         device=x.device,
+        comm=x.comm,
         copy=False,
     )


