Add solver for distributed triangular systems #1236

Merged Apr 12, 2024 · 36 commits · showing changes from 24 commits
8ffa51c
added my first ideas for an implementation of solve_triangular (not a…
Oct 10, 2023
8003642
readme test
FOsterfeld Oct 10, 2023
e6c88e8
dummy commit to find out whether CI works...
mrfh92 Oct 12, 2023
782c7b2
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
mrfh92 Dec 18, 2023
524e294
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
mrfh92 Jan 3, 2024
3749dbf
small improvements in solve_triangular for the case A.split == 1
FOsterfeld Jan 16, 2024
de2fe73
small improvements in solve_triangular for the case A.split == 0
FOsterfeld Jan 17, 2024
85fd3af
added tests for solve_triangular
FOsterfeld Jan 17, 2024
bcac92e
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
mrfh92 Jan 22, 2024
aa27c2c
implemented the case A.split == 1 with scatterv instead of bcast to r…
FOsterfeld Feb 6, 2024
6fe7ed0
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
mrfh92 Feb 6, 2024
7accc48
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
mrfh92 Feb 20, 2024
7f2ccc5
implemented solve_triangular for batched inputs and b.split = None
FOsterfeld Feb 20, 2024
7aa3f8d
Merge branch 'features/1096-Provide_a_solver_for_triangular_systems' …
FOsterfeld Feb 20, 2024
8d2af3e
implemented solve_triangular for A.split = None, b.split = -2
FOsterfeld Feb 27, 2024
fd7e2d7
added tests for batched input of solve_triangular
FOsterfeld Feb 27, 2024
c8b0890
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
mrfh92 Feb 27, 2024
4591260
fixed a bug when not using the cpu
FOsterfeld Feb 29, 2024
c62434e
Merge branch 'features/1096-Provide_a_solver_for_triangular_systems' …
FOsterfeld Feb 29, 2024
25bbabd
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
FOsterfeld Feb 29, 2024
1bcd97f
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
FOsterfeld Feb 29, 2024
cf84b88
specify device
FOsterfeld Feb 29, 2024
18d3f00
fixed on gpu
FOsterfeld Mar 5, 2024
4f3cd08
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
FOsterfeld Mar 5, 2024
bac8fc7
improved exception handling for invalid input data
FOsterfeld Mar 26, 2024
22507ae
added tests for higher dimensional batches
FOsterfeld Mar 26, 2024
96fa81f
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
mrfh92 Apr 3, 2024
bfac0c3
changed error message for non-equal local sizes along split axis
Apr 3, 2024
db370fe
updated docstring according to our conventions
Apr 3, 2024
f790e23
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
ClaudiaComito Apr 11, 2024
af10937
Update heat/core/linalg/solver.py
mrfh92 Apr 11, 2024
795053f
Update heat/core/linalg/solver.py
mrfh92 Apr 11, 2024
ba3578a
Update heat/core/linalg/solver.py
mrfh92 Apr 11, 2024
e248a42
Update solver.py
mrfh92 Apr 11, 2024
d9edee7
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
mrfh92 Apr 11, 2024
6589754
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
mrfh92 Apr 12, 2024
187 changes: 186 additions & 1 deletion heat/core/linalg/solver.py
@@ -6,10 +6,11 @@
from ..dndarray import DNDarray
from ..sanitation import sanitize_out
from typing import List, Dict, Any, TypeVar, Union, Tuple, Optional
from .. import factories

import torch

__all__ = ["cg", "lanczos"]
__all__ = ["cg", "lanczos", "solve_triangular"]


def cg(A: DNDarray, b: DNDarray, x0: DNDarray, out: Optional[DNDarray] = None) -> DNDarray:
@@ -269,3 +270,187 @@ def lanczos(
V.resplit_(axis=None)

return V, T


def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray:
"""
Solve upper triangular systems of linear equations.

Input:
A - an upper triangular (possibly batched) invertible square (n x n) matrix
b - (possibly batched) n x k matrix

Output:
The unique solution x of A * x = b.

Vectors b have to be given as n x 1 matrices.
Contributor: This is a bit confusing, also not following our documentation scheme.

Collaborator Author: I have updated the docstring according to our conventions.

"""
if not isinstance(A, DNDarray) or not isinstance(b, DNDarray):
raise RuntimeError("Arguments need to be a DNDarrays.")
Contributor: Suggested change:
- raise RuntimeError("Arguments need to be a DNDarrays.")
+ raise RuntimeError(f"Arguments need to be of type DNDarray, got {type(A)}, {type(b)}.")

Contributor: This should be a TypeError.

if not A.ndim >= 2:
raise RuntimeError("A needs to be a 2D matrix.")
if not b.ndim == A.ndim:
raise RuntimeError("b needs to be a 2D matrix.")
Contributor: These would be ValueErrors, I think. The message is a bit confusing, because we do allow n-D arrays.

if not A.shape[-2] == A.shape[-1]:
raise RuntimeError("A needs to be a square matrix.")
Contributor: ValueError, see https://docs.python.org/3/library/exceptions.html#. Suggested wording: "... a square matrix or a batch of square matrices".


batch_dim = A.ndim - 2
batch_shape = A.shape[:batch_dim]

if not A.shape[:batch_dim] == b.shape[:batch_dim]:
raise RuntimeError("Batch dimensions of A and b must be of the same shape.")
if b.split == batch_dim + 1:
raise RuntimeError("split=1 is not allowed for the right hand side.")
if not b.shape[batch_dim] == A.shape[-1]:
raise RuntimeError("Dimension mismath of A and b.")
Contributor: All ValueErrors, I think.


if (
A.split is not None and A.split < batch_dim or b.split is not None and b.split < batch_dim
): # batch split
if A.split != b.split:
raise RuntimeError(
"If a split dimension is a batch dimension, A and b must have the same split dimension."
Contributor: ValueError. Let's help the users out here: suggest what they're supposed to change in practice.

)
else:
if (
A.split is not None and b.split is not None
): # both la dimensions split --> b.split = batch_dim
if not all(A.lshape_map[:, A.split] == b.lshape_map[:, batch_dim]):
raise RuntimeError("Local arrays of A and b have different sizes.")
Contributor: What's the user supposed to do about this?

Member: I'm not sure whether this error can even occur, because at this point both DNDarrays should have the same size in that axis. (Otherwise lines 294 or 304 should have thrown an exception.) What do you think @mrfh92?

Collaborator Author: Actually, this can happen if at least one of the arrays does not follow the "standard" splitting scheme, e.g., if it is unbalanced for whatever reason.

Collaborator Author: I have updated the error message accordingly.
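For context, a hypothetical two-process illustration of how the local sizes can disagree even though the global shapes match (is_split builds a DNDarray from whatever local chunk each rank supplies, so the distribution need not be balanced):

import heat as ht
import torch

rank = ht.MPI_WORLD.rank
# deliberately unbalanced: rank 0 holds 3 rows of A, rank 1 holds 1 row
local_rows = 3 if rank == 0 else 1
A = ht.array(torch.ones(local_rows, 4), is_split=0)
b = ht.ones((4, 1), split=0)  # standard balanced distribution: 2 rows per rank
# A.lshape_map[:, 0] is [3, 1] while b.lshape_map[:, 0] is [2, 2]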


comm = A.comm
dev = A.device
tdev = dev.torch_device

nprocs = comm.Get_size()

if A.split is None: # A not split
if b.split is None:
x = torch.linalg.solve_triangular(A.larray, b.larray, upper=True)

return factories.array(x, dtype=b.dtype, device=dev, comm=comm)
else: # A not split, b.split == -2
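# Prefix sums of the local row counts of b give each rank's global row
# offsets; the same offsets select the matching block of columns of A below.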
b_lshapes_cum = torch.hstack(
[
torch.zeros(1, dtype=torch.int32, device=tdev),
torch.cumsum(b.lshape_map[:, -2], 0),
]
)

btilde_loc = b.larray.clone()
A_loc = A.larray[..., b_lshapes_cum[comm.rank] : b_lshapes_cum[comm.rank + 1]]

x = factories.zeros_like(b, device=dev, comm=comm)
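# Block back-substitution from the last block row upwards: rank i solves
# its diagonal block for its rows of x, Scatterv distributes the update
# A_loc @ x to the ranks above i, and those ranks subtract it from their
# local right-hand sides before their own turn.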

for i in range(nprocs - 1, 0, -1):
count = x.lshape_map[:, batch_dim].to(torch.device("cpu")).clone().numpy()
displ = b_lshapes_cum[:-1].to(torch.device("cpu")).clone().numpy()
count[i:] = 0 # nothing to send, as there are only zero rows
displ[i:] = 0

res_send = torch.empty(0)
res_recv = torch.zeros((*batch_shape, count[comm.rank], b.shape[-1]), device=tdev)

if comm.rank == i:
x.larray = torch.linalg.solve_triangular(
A_loc[..., b_lshapes_cum[i] : b_lshapes_cum[i + 1], :],
btilde_loc,
upper=True,
)
res_send = A_loc @ x.larray

comm.Scatterv((res_send, count, displ), res_recv, root=i, axis=batch_dim)

if comm.rank < i:
btilde_loc -= res_recv

if comm.rank == 0:
x.larray = torch.linalg.solve_triangular(
A_loc[..., : b_lshapes_cum[1], :], btilde_loc, upper=True
)

return x

if A.split < batch_dim: # batch split
x = factories.zeros_like(b, device=dev, comm=comm, split=A.split)
x.larray = torch.linalg.solve_triangular(A.larray, b.larray, upper=True)

return x

if A.split >= batch_dim: # both splits in la dims
A_lshapes_cum = torch.hstack(
[
torch.zeros(1, dtype=torch.int32, device=tdev),
torch.cumsum(A.lshape_map[:, A.split], 0),
]
)

if b.split is None:
btilde_loc = b.larray[
..., A_lshapes_cum[comm.rank] : A_lshapes_cum[comm.rank + 1], :
].clone()
else: # b is split at la dim 0
btilde_loc = b.larray.clone()

x = factories.zeros_like(
b, device=dev, comm=comm, split=batch_dim
) # split at la dim 0 in case b is not split

if A.split == batch_dim + 1:
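# A is split along its columns here: rank i owns the block column holding
# the i-th diagonal block, so after the local solve its contribution
# A.larray @ x.larray is scattered to the ranks above i.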
for i in range(nprocs - 1, 0, -1):
count = x.lshape_map[:, batch_dim].to(torch.device("cpu")).clone().numpy()
displ = A_lshapes_cum[:-1].to(torch.device("cpu")).clone().numpy()
count[i:] = 0 # nothing to send, as there are only zero rows
displ[i:] = 0

res_send = torch.empty(0)
res_recv = torch.zeros((*batch_shape, count[comm.rank], b.shape[-1]), device=tdev)

if comm.rank == i:
x.larray = torch.linalg.solve_triangular(
A.larray[..., A_lshapes_cum[i] : A_lshapes_cum[i + 1], :],
btilde_loc,
upper=True,
)
res_send = A.larray @ x.larray

comm.Scatterv((res_send, count, displ), res_recv, root=i, axis=batch_dim)

if comm.rank < i:
btilde_loc -= res_recv

if comm.rank == 0:
x.larray = torch.linalg.solve_triangular(
A.larray[..., : A_lshapes_cum[1], :], btilde_loc, upper=True
)

else: # split dim is la dim 0
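# A is split along its rows here: rank i solves the diagonal block, the
# block solution x_from_i is broadcast, and every rank above i subtracts
# its local column slice of A times x_from_i from its right-hand side.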
for i in range(nprocs - 1, 0, -1):
idims = tuple(x.lshape_map[i])
if comm.rank == i:
x.larray = torch.linalg.solve_triangular(
A.larray[..., :, A_lshapes_cum[i] : A_lshapes_cum[i + 1]],
btilde_loc,
upper=True,
)
x_from_i = x.larray
else:
x_from_i = torch.zeros(
idims,
dtype=b.dtype.torch_type(),
device=tdev,
)

comm.Bcast(x_from_i, root=i)

if comm.rank < i:
btilde_loc -= (
A.larray[..., :, A_lshapes_cum[i] : A_lshapes_cum[i + 1]] @ x_from_i
)

if comm.rank == 0:
x.larray = torch.linalg.solve_triangular(
A.larray[..., :, : A_lshapes_cum[1]], btilde_loc, upper=True
)

return x
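For orientation, a minimal usage sketch of the new solver; a sketch under the same assumptions as the tests below (an upper triangular, diagonally dominant A, with ht.array, ht.eye and ht.allclose as used elsewhere in the codebase), not part of the PR itself:

import heat as ht
import torch

n = 100
# diagonal dominance keeps the back substitution numerically stable
At = torch.triu(torch.rand(n, n) + n * torch.eye(n))
A = ht.array(At, split=0)  # rows of A distributed across processes
b = ht.eye(n, split=0)     # right-hand side with a matching distribution

x = ht.linalg.solve_triangular(A, b)  # unique solution of A @ x = b
print(ht.allclose(A @ x, b))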
77 changes: 77 additions & 0 deletions heat/core/linalg/tests/test_solver.py
@@ -135,3 +135,80 @@ def test_lanczos(self):
with self.assertRaises(NotImplementedError):
A = ht.random.randn(10, 10, split=1)
V, T = ht.lanczos(A, m=3)

def test_solve_triangular(self):
torch.manual_seed(42)
tdev = ht.get_device().torch_device

# non-batched tests
k = 100 # data dimension size

# random triangular matrix inversion
at = torch.rand((k, k))
# at += torch.eye(k)
at += 1e2 * torch.ones_like(at) # make gaussian elimination more stable
at = torch.triu(at).to(tdev)

ct = torch.linalg.solve_triangular(at, torch.eye(k, device=tdev), upper=True)

a = ht.factories.asarray(at, copy=True)
c = ht.factories.asarray(ct, copy=True)
b = ht.eye(k)

for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
a.resplit_(s0)
b.resplit_(s1)

res = ht.linalg.solve_triangular(a, b)
self.assertTrue(ht.allclose(res, c))

# triangular ones inversion
# for this test case, the results should be exact
at = torch.triu(torch.ones_like(at)).to(tdev)
ct = torch.linalg.solve_triangular(at, torch.eye(k, device=tdev), upper=True)

a = ht.factories.asarray(at, copy=True)
c = ht.factories.asarray(ct, copy=True)

for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
a.resplit_(s0)
b.resplit_(s1)

res = ht.linalg.solve_triangular(a, b)
self.assertTrue(ht.equal(res, c))

# batched tests
batch_shape = (10,) # batch dimensions shape
Contributor: Shall we test higher-dimensional batches as well, i.e. 6-dimensional with split=3?

Member: That's a good idea, I have done so with 22507ae.

# batch_shape = tuple() # no batch dimensions
m = 100 # data dimension size

at = torch.rand((*batch_shape, m, m))
# at += torch.eye(k)
at += 1e2 * torch.ones_like(at) # make gaussian elimination more stable
at = torch.triu(at).to(tdev)
bt = torch.eye(m).expand((*batch_shape, -1, -1)).to(tdev)

ct = torch.linalg.solve_triangular(at, bt, upper=True)

a = ht.factories.asarray(at, copy=True)
c = ht.factories.asarray(ct, copy=True)
b = ht.factories.asarray(bt, copy=True)

# split in linalg dimension or none
for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
a.resplit_(s0)
b.resplit_(s1)

res = ht.linalg.solve_triangular(a, b)

self.assertTrue(ht.allclose(c, res))

# split in batch dimension
s = 0
a.resplit_(s)
b.resplit_(s)
c.resplit_(s)

res = ht.linalg.solve_triangular(a, b)

self.assertTrue(ht.allclose(c, res))
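As a usage note, the Scatterv and Bcast code paths above are only exercised when the tests run on several MPI processes; with a pytest-based setup that is presumably something like:

mpirun -n 4 python -m pytest heat/core/linalg/tests/test_solver.py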