Fixed initialization of DNDarrays communicator in some routines #1075

Merged
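
Summary of the change (a sketch added here for readability, not the PR's own description): the routines in heat/core/linalg/basics.py create intermediate and output DNDarrays through factories.zeros, factories.ones, and factories.array. When the comm argument is omitted, these factories fall back to the default (global) communicator, so a result could end up on a different communicator than its inputs. The diff below passes comm=a.comm (or b.comm) explicitly. A minimal illustration of the pattern, assuming Heat is importable as ht and using an arbitrary example array:

import heat as ht

# Any routine input; its communicator is available as a.comm.
a = ht.ones((4, 4), split=0)

# Omitting `comm` lets the factory fall back to the default communicator,
# which is not necessarily the communicator of `a`:
c_before = ht.zeros(a.shape, dtype=a.dtype, device=a.device)

# The fix applied throughout this PR: propagate the input's communicator.
c_after = ht.zeros(a.shape, dtype=a.dtype, device=a.device, comm=a.comm)

print(c_after.comm is a.comm)  # expected: True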
Commits (23)
872ae30
fixed initialization of DNDarrays in some routines
AsRaNi1 Jan 15, 2023
93d0793
Merge branch 'helmholtz-analytics:main' into Communicator-not-properl…
AsRaNi1 Jan 16, 2023
3b3bfec
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2023
9ed3f60
resolved improper init of DNDarrays in some routines
AsRaNi1 Jan 16, 2023
3acb3e3
resolved improper init of DNDarrays in some routines
AsRaNi1 Jan 16, 2023
6ff1d6f
Merge branch 'main' into Communicator-not-properly-initialized-when-c…
AsRaNi1 Jan 30, 2023
cf40aff
resolved improper init of DNDarrays in some routines
AsRaNi1 Jan 30, 2023
099372c
Merge branch 'main' into Communicator-not-properly-initialized-when-c…
AsRaNi1 Feb 4, 2023
5408dff
Merge branch 'helmholtz-analytics:main' into Communicator-not-properl…
AsRaNi1 Feb 13, 2023
552b10e
Merge branch 'helmholtz-analytics:main' into Communicator-not-properl…
AsRaNi1 Feb 15, 2023
9dbdd1b
rectified more places where comm was not specified
AsRaNi1 Feb 15, 2023
b88f88e
Merge branch 'main' into Communicator-not-properly-initialized-when-c…
AsRaNi1 Mar 5, 2023
b69f6a0
Merge branch 'main' into Communicator-not-properly-initialized-when-c…
AsRaNi1 Mar 16, 2023
1824fb5
Merge branch 'main' into Communicator-not-properly-initialized-when-c…
AsRaNi1 Mar 18, 2023
b0af45a
Merge branch 'main' into Communicator-not-properly-initialized-when-c…
ClaudiaComito Mar 20, 2023
b9bade9
Merge branch 'main' into Communicator-not-properly-initialized-when-c…
ClaudiaComito Mar 29, 2023
1ba3d38
Merge branch 'main' into Communicator-not-properly-initialized-when-c…
ClaudiaComito Apr 17, 2023
c850f87
Merge branch 'main' into Communicator-not-properly-initialized-when-c…
ClaudiaComito May 22, 2023
c858d6d
Merge branch 'main' into Communicator-not-properly-initialized-when-c…
mrfh92 May 22, 2023
045cb5c
Update manipulations.py
mrfh92 May 22, 2023
9fbfffd
Merge branch 'main' into Communicator-not-properly-initialized-when-c…
mrfh92 May 22, 2023
6bf29c4
Merge branch 'main' into Communicator-not-properly-initialized-when-c…
mrfh92 May 23, 2023
a373be8
Merge branch 'main' into Communicator-not-properly-initialized-when-c…
mrfh92 May 24, 2023
39 changes: 23 additions & 16 deletions heat/core/linalg/basics.py
@@ -116,7 +116,7 @@ def cross(
a_2d = True
shape = tuple(1 if i == axisa else j for i, j in enumerate(a.shape))
a = manipulations.concatenate(
[a, factories.zeros(shape, dtype=a.dtype, device=a.device)], axis=axisa
[a, factories.zeros(shape, dtype=a.dtype, device=a.device, comm=a.comm)], axis=axisa
)
if b.shape[axisb] == 2:
b_2d = True
@@ -207,7 +207,7 @@ def det(a: DNDarray) -> DNDarray:

acopy = a.copy()
acopy = manipulations.reshape(acopy, (-1, m, m), new_split=a.split - a.ndim + 3)
adet = factories.ones(acopy.shape[0], dtype=a.dtype, device=a.device)
adet = factories.ones(acopy.shape[0], dtype=a.dtype, device=a.device, comm=a.comm)

for k in range(adet.shape[0]):
m = 0
@@ -475,6 +475,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
[1/1] tensor([[3., 1., 1., 1., 1., 1., 1.],
[4., 1., 1., 1., 1., 1., 1.]])
>>> linalg.matmul(a, b).larray

[0/1] tensor([[18., 8., 9., 10.],
[14., 6., 7., 8.],
[18., 7., 8., 9.],
@@ -520,7 +521,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
if a.split is None and b.split is None: # matmul from torch
if len(a.gshape) < 2 or len(b.gshape) < 2 or not allow_resplit:
# if either of A or B is a vector
ret = factories.array(torch.matmul(a.larray, b.larray), device=a.device)
ret = factories.array(torch.matmul(a.larray, b.larray), device=a.device, comm=a.comm)
if gpu_int_flag:
ret = og_type(ret, device=a.device)
return ret
@@ -529,7 +530,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
slice_0 = a.comm.chunk(a.shape, a.split)[2][0]
hold = a.larray @ b.larray

c = factories.zeros((a.gshape[-2], b.gshape[1]), dtype=c_type, device=a.device)
c = factories.zeros((a.gshape[-2], b.gshape[1]), dtype=c_type, device=a.device, comm=a.comm)
c.larray[slice_0.start : slice_0.stop, :] += hold
c.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
if gpu_int_flag:
@@ -544,7 +545,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
b.resplit_(0)
res = a.larray @ b.larray
a.comm.Allreduce(MPI.IN_PLACE, res, MPI.SUM)
ret = factories.array(res, split=None, device=a.device)
ret = factories.array(res, split=None, device=a.device, comm=a.comm)
if gpu_int_flag:
ret = og_type(ret, device=a.device)
return ret
@@ -567,7 +568,9 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
) and not vector_flag:
split = a.split if a.split is not None else b.split
split = split if not vector_flag else 0
c = factories.zeros((a.gshape[-2], b.gshape[1]), split=split, dtype=c_type, device=a.device)
c = factories.zeros(
(a.gshape[-2], b.gshape[1]), split=split, dtype=c_type, device=a.device, comm=a.comm
)
c.larray += a.larray @ b.larray

ret = c if not vector_flag else c.squeeze()
@@ -582,7 +585,9 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
c += a.larray @ b.larray[a_idx[1].start : a_idx[1].start + a.lshape[-1], :]
a.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
c = c if not vector_flag else c.squeeze()
ret = factories.array(c, split=a.split if b.gshape[1] > 1 else 0, device=a.device)
ret = factories.array(
c, split=a.split if b.gshape[1] > 1 else 0, device=a.device, comm=a.comm
)
if gpu_int_flag:
ret = og_type(ret, device=a.device)
return ret
@@ -593,7 +598,9 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
c += a.larray[:, b_idx[0].start : b_idx[0].start + b.lshape[0]] @ b.larray
b.comm.Allreduce(MPI.IN_PLACE, c, MPI.SUM)
c = c if not vector_flag else c.squeeze()
ret = factories.array(c, split=b.split if a.gshape[-2] > 1 else 0, device=a.device)
ret = factories.array(
c, split=b.split if a.gshape[-2] > 1 else 0, device=a.device, comm=a.comm
)
if gpu_int_flag:
ret = og_type(ret, device=a.device)
return ret
@@ -608,7 +615,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
c = c if not vector_flag else c.squeeze()
split = a.split if b.gshape[1] > 1 else 0
split = split if not vector_flag else 0
ret = factories.array(c, split=split, device=a.device)
ret = factories.array(c, split=split, device=a.device, comm=a.comm)
if gpu_int_flag:
ret = og_type(ret, device=a.device)
return ret
@@ -619,7 +626,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
c = c if not vector_flag else c.squeeze()
split = b.split if a.gshape[1] > 1 else 0
split = split if not vector_flag else 0
ret = factories.array(c, is_split=split, device=a.device)
ret = factories.array(c, is_split=split, device=a.device, comm=a.comm)
if gpu_int_flag:
ret = og_type(ret, device=a.device)
return ret
@@ -695,10 +702,10 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:

# for the communication scheme, the output array needs to be created
c_shape = (a.gshape[-2], b.gshape[1])
c = factories.zeros(c_shape, split=a.split, dtype=c_type, device=a.device)
c = factories.zeros(c_shape, split=a.split, dtype=c_type, device=a.device, comm=a.comm)

# get the index map for c
c_index_map = factories.zeros((c.comm.size, 2, 2), device=a.device)
c_index_map = factories.zeros((c.comm.size, 2, 2), device=a.device, comm=a.comm)
c_idx = c.comm.chunk(c.shape, c.split)[2]
c_index_map[c.comm.rank, 0, :] = (c_idx[0].start, c_idx[0].stop)
c_index_map[c.comm.rank, 1, :] = (c_idx[1].start, c_idx[1].stop)
@@ -919,7 +926,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
if c_loc.nelement() == 1:
c_loc = torch.tensor(c_loc, device=tdev)

c = factories.array(c_loc, is_split=0, device=a.device)
c = factories.array(c_loc, is_split=0, device=a.device, comm=a.comm)
if gpu_int_flag:
c = og_type(c, device=a.device)
return c
@@ -1023,7 +1030,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
c.larray[:, : b_node_rem_s1.shape[1]] += a_rem @ b_node_rem_s1
del a_lp_data[pr]
if vector_flag:
c = factories.array(c.larray.squeeze(), is_split=0, device=a.device)
c = factories.array(c.larray.squeeze(), is_split=0, device=a.device, comm=a.comm)
if gpu_int_flag:
c = og_type(c, device=a.device)
return c
@@ -1066,7 +1073,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
c.larray[: sp0 - st0, st1:sp1] += a.larray @ b_lp_data[pr]
del b_lp_data[pr]
if vector_flag:
c = factories.array(c.larray.squeeze(), is_split=0, device=a.device)
c = factories.array(c.larray.squeeze(), is_split=0, device=a.device, comm=a.comm)
if gpu_int_flag:
c = og_type(c, device=a.device)

@@ -1090,7 +1097,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
if vector_flag:
split = 0
res = res.squeeze()
c = factories.array(res, split=split, device=a.device)
c = factories.array(res, split=split, device=a.device, comm=a.comm)
if gpu_int_flag:
c = og_type(c, device=a.device)
return c
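
Usage note (a sketch under the assumption that the changes above are applied; the arrays and shapes are illustrative): with comm propagated explicitly, the result of a routine such as ht.linalg.matmul should now carry the communicator of its inputs rather than the default one.

import heat as ht

a = ht.random.randn(4, 4, split=0)
b = ht.random.randn(4, 4, split=0)

c = ht.linalg.matmul(a, b)
# With the explicit comm=... arguments added in this file, the output is
# expected to share the inputs' communicator:
print(c.comm is a.comm)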