Add solver for distributed triangular systems #1236
@@ -6,10 +6,11 @@
 from ..dndarray import DNDarray
 from ..sanitation import sanitize_out
 from typing import List, Dict, Any, TypeVar, Union, Tuple, Optional
+from .. import factories

 import torch

-__all__ = ["cg", "lanczos"]
+__all__ = ["cg", "lanczos", "solve_triangular"]


 def cg(A: DNDarray, b: DNDarray, x0: DNDarray, out: Optional[DNDarray] = None) -> DNDarray:
@@ -269,3 +270,187 @@ def lanczos(
    V.resplit_(axis=None)

    return V, T


def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray:
    """
    Solve upper triangular systems of linear equations.

    Input:
        A - an upper triangular (possibly batched) invertible square (n x n) matrix
        b - (possibly batched) n x k matrix

    Output:
        The unique solution x of A * x = b.

    Vectors b have to be given as n x 1 matrices.
    """
    if not isinstance(A, DNDarray) or not isinstance(b, DNDarray):
        raise RuntimeError("Arguments need to be DNDarrays.")

Review comment: This should be a TypeError.

    if not A.ndim >= 2:
        raise RuntimeError("A needs to be a 2D matrix.")
    if not b.ndim == A.ndim:
        raise RuntimeError("b needs to be a 2D matrix.")

Review comment: These would be ValueErrors, I think. The message is also a bit confusing, because we do allow n-D arrays.

    if not A.shape[-2] == A.shape[-1]:
        raise RuntimeError("A needs to be a square matrix.")

Review comment: ValueError (see https://docs.python.org/3/library/exceptions.html), and the message should read "... a square matrix or a batch of square matrices".

    batch_dim = A.ndim - 2
    batch_shape = A.shape[:batch_dim]

    if not A.shape[:batch_dim] == b.shape[:batch_dim]:
        raise RuntimeError("Batch dimensions of A and b must be of the same shape.")
    if b.split == batch_dim + 1:
        raise RuntimeError("split=1 is not allowed for the right hand side.")
    if not b.shape[batch_dim] == A.shape[-1]:
        raise RuntimeError("Dimension mismatch of A and b.")

Review comment: All ValueErrors, I think.
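
A rough sketch of what the validation could look like with the exception types proposed in the review; the message wording is illustrative and not taken from the PR:

    # Illustrative only: exception types follow the review suggestions above.
    if not isinstance(A, DNDarray) or not isinstance(b, DNDarray):
        raise TypeError(f"Arguments need to be DNDarrays, got {type(A)} and {type(b)}.")
    if A.ndim < 2:
        raise ValueError("A needs to be a (possibly batched) 2D matrix.")
    if b.ndim != A.ndim:
        raise ValueError("b must have the same number of dimensions as A.")
    if A.shape[-2] != A.shape[-1]:
        raise ValueError("A needs to be a square matrix or a batch of square matrices.")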
    if (
        A.split is not None and A.split < batch_dim or b.split is not None and b.split < batch_dim
    ):  # batch split
        if A.split != b.split:
            raise RuntimeError(
                "If a split dimension is a batch dimension, A and b must have the same split dimension."
            )
    else:
        if (
            A.split is not None and b.split is not None
        ):  # both la dimensions split --> b.split = batch_dim
            if not all(A.lshape_map[:, A.split] == b.lshape_map[:, batch_dim]):
                raise RuntimeError("Local arrays of A and b have different sizes.")

Review comment (on the batch-split error): ValueError. Let's help the users out here and suggest what they are supposed to change in practice.

Review thread (on the "different sizes" error):
- What is the user supposed to do about this?
- I'm not sure whether this error can even occur, because at this point both DNDarrays should have the same size in that axis. (Otherwise lines 294 or 304 should have thrown an exception.)
- Actually, this can happen if at least one of the arrays does not follow the "standard" splitting scheme, e.g., if it is unbalanced for whatever reason.
- I have updated the error message accordingly.
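
Regarding the "what is the user supposed to do" question, one plausible remedy is to restore Heat's standard balanced distribution before calling the solver. A small sketch from the user's side; whether the error message ends up recommending this is an assumption, although balance_() and is_balanced() are existing DNDarray methods:

    # Illustrative remedy (an assumption, not part of the PR): rebalance operands
    # that were manipulated into a non-standard, unbalanced distribution.
    if not A.is_balanced():
        A.balance_()  # redistribute to Heat's default balanced chunks
    if not b.is_balanced():
        b.balance_()
    x = ht.linalg.solve_triangular(A, b)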
    comm = A.comm
    dev = A.device
    tdev = dev.torch_device

    nprocs = comm.Get_size()

    if A.split is None:  # A not split
        if b.split is None:
            x = torch.linalg.solve_triangular(A.larray, b.larray, upper=True)

            return factories.array(x, dtype=b.dtype, device=dev, comm=comm)
        else:  # A not split, b.split == -2
            b_lshapes_cum = torch.hstack(
                [
                    torch.zeros(1, dtype=torch.int32, device=tdev),
                    torch.cumsum(b.lshape_map[:, -2], 0),
                ]
            )

            btilde_loc = b.larray.clone()
            A_loc = A.larray[..., b_lshapes_cum[comm.rank] : b_lshapes_cum[comm.rank + 1]]

            x = factories.zeros_like(b, device=dev, comm=comm)

            for i in range(nprocs - 1, 0, -1):
                count = x.lshape_map[:, batch_dim].to(torch.device("cpu")).clone().numpy()
                displ = b_lshapes_cum[:-1].to(torch.device("cpu")).clone().numpy()
                count[i:] = 0  # nothing to send, as there are only zero rows
                displ[i:] = 0

                res_send = torch.empty(0)
                res_recv = torch.zeros((*batch_shape, count[comm.rank], b.shape[-1]), device=tdev)

                if comm.rank == i:
                    x.larray = torch.linalg.solve_triangular(
                        A_loc[..., b_lshapes_cum[i] : b_lshapes_cum[i + 1], :],
                        btilde_loc,
                        upper=True,
                    )
                    res_send = A_loc @ x.larray

                comm.Scatterv((res_send, count, displ), res_recv, root=i, axis=batch_dim)

                if comm.rank < i:
                    btilde_loc -= res_recv

            if comm.rank == 0:
                x.larray = torch.linalg.solve_triangular(
                    A_loc[..., : b_lshapes_cum[1], :], btilde_loc, upper=True
                )

            return x

    if A.split < batch_dim:  # batch split
        x = factories.zeros_like(b, device=dev, comm=comm, split=A.split)
        x.larray = torch.linalg.solve_triangular(A.larray, b.larray, upper=True)

        return x

    if A.split >= batch_dim:  # both splits in la dims
        A_lshapes_cum = torch.hstack(
            [
                torch.zeros(1, dtype=torch.int32, device=tdev),
                torch.cumsum(A.lshape_map[:, A.split], 0),
            ]
        )

        if b.split is None:
            btilde_loc = b.larray[
                ..., A_lshapes_cum[comm.rank] : A_lshapes_cum[comm.rank + 1], :
            ].clone()
        else:  # b is split at la dim 0
            btilde_loc = b.larray.clone()

        x = factories.zeros_like(
            b, device=dev, comm=comm, split=batch_dim
        )  # split at la dim 0 in case b is not split

        if A.split == batch_dim + 1:
            for i in range(nprocs - 1, 0, -1):
                count = x.lshape_map[:, batch_dim].to(torch.device("cpu")).clone().numpy()
                displ = A_lshapes_cum[:-1].to(torch.device("cpu")).clone().numpy()
                count[i:] = 0  # nothing to send, as there are only zero rows
                displ[i:] = 0

                res_send = torch.empty(0)
                res_recv = torch.zeros((*batch_shape, count[comm.rank], b.shape[-1]), device=tdev)

                if comm.rank == i:
                    x.larray = torch.linalg.solve_triangular(
                        A.larray[..., A_lshapes_cum[i] : A_lshapes_cum[i + 1], :],
                        btilde_loc,
                        upper=True,
                    )
                    res_send = A.larray @ x.larray

                comm.Scatterv((res_send, count, displ), res_recv, root=i, axis=batch_dim)

                if comm.rank < i:
                    btilde_loc -= res_recv

            if comm.rank == 0:
                x.larray = torch.linalg.solve_triangular(
                    A.larray[..., : A_lshapes_cum[1], :], btilde_loc, upper=True
                )

        else:  # split dim is la dim 0
            for i in range(nprocs - 1, 0, -1):
                idims = tuple(x.lshape_map[i])
                if comm.rank == i:
                    x.larray = torch.linalg.solve_triangular(
                        A.larray[..., :, A_lshapes_cum[i] : A_lshapes_cum[i + 1]],
                        btilde_loc,
                        upper=True,
                    )
                    x_from_i = x.larray
                else:
                    x_from_i = torch.zeros(
                        idims,
                        dtype=b.dtype.torch_type(),
                        device=tdev,
                    )

                comm.Bcast(x_from_i, root=i)

                if comm.rank < i:
                    btilde_loc -= (
                        A.larray[..., :, A_lshapes_cum[i] : A_lshapes_cum[i + 1]] @ x_from_i
                    )

            if comm.rank == 0:
                x.larray = torch.linalg.solve_triangular(
                    A.larray[..., :, : A_lshapes_cum[1]], btilde_loc, upper=True
                )

    return x
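
For context, a minimal usage sketch of the new solver from a user's perspective; the matrix size, the diagonal shift that keeps the system well conditioned, and the loosened tolerance are illustrative choices, not taken from the PR:

    # Run with e.g. `mpirun -n 4 python example.py`; each rank holds a row block of A and b.
    import heat as ht

    n = 100
    # Upper triangular, diagonally dominant matrix, distributed along the rows (split=0).
    A = ht.triu(ht.random.rand(n, n, split=0) + n * ht.eye(n, split=0))
    b = ht.random.rand(n, 1, split=0)  # right-hand side with the same row distribution

    x = ht.linalg.solve_triangular(A, b)             # distributed block back-substitution
    print(ht.allclose(A @ x, b, atol=1e-4))          # residual check; tolerance loosened for float32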
@@ -135,3 +135,80 @@ def test_lanczos(self):
        with self.assertRaises(NotImplementedError):
            A = ht.random.randn(10, 10, split=1)
            V, T = ht.lanczos(A, m=3)

    def test_solve_triangular(self):
        torch.manual_seed(42)
        tdev = ht.get_device().torch_device

        # non-batched tests
        k = 100  # data dimension size

        # random triangular matrix inversion
        at = torch.rand((k, k))
        # at += torch.eye(k)
        at += 1e2 * torch.ones_like(at)  # make gaussian elimination more stable
        at = torch.triu(at).to(tdev)

        ct = torch.linalg.solve_triangular(at, torch.eye(k, device=tdev), upper=True)

        a = ht.factories.asarray(at, copy=True)
        c = ht.factories.asarray(ct, copy=True)
        b = ht.eye(k)

        for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
            a.resplit_(s0)
            b.resplit_(s1)

            res = ht.linalg.solve_triangular(a, b)
            self.assertTrue(ht.allclose(res, c))

        # triangular ones inversion
        # for this test case, the results should be exact
        at = torch.triu(torch.ones_like(at)).to(tdev)
        ct = torch.linalg.solve_triangular(at, torch.eye(k, device=tdev), upper=True)

        a = ht.factories.asarray(at, copy=True)
        c = ht.factories.asarray(ct, copy=True)

        for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
            a.resplit_(s0)
            b.resplit_(s1)

            res = ht.linalg.solve_triangular(a, b)
            self.assertTrue(ht.equal(res, c))

        # batched tests
        batch_shape = (10,)  # batch dimensions shape

Review thread:
- Shall we test higher-dimensional batches as well, i.e. 6-dimensional with split=3?
- That's a good idea, I have done so with 22507ae.
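
A sketch of what such a higher-dimensional batch case might look like inside the same test method; the shapes and the split=3 choice follow the reviewer's suggestion, and the actual test added in 22507ae may differ:

        # Illustrative only: 6-dimensional input with four batch dimensions, split along batch axis 3.
        batch_shape = (2, 2, 2, 2)
        m = 10
        at = torch.triu(torch.rand((*batch_shape, m, m)) + 1e2).to(tdev)
        bt = torch.eye(m).expand((*batch_shape, -1, -1)).to(tdev)
        ct = torch.linalg.solve_triangular(at, bt, upper=True)

        a = ht.factories.asarray(at, copy=True)
        b = ht.factories.asarray(bt, copy=True)
        c = ht.factories.asarray(ct, copy=True)

        for arr in (a, b, c):
            arr.resplit_(3)  # split in a batch dimension, as suggested

        res = ht.linalg.solve_triangular(a, b)
        self.assertTrue(ht.allclose(c, res))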
        # batch_shape = tuple()  # no batch dimensions
        m = 100  # data dimension size

        at = torch.rand((*batch_shape, m, m))
        # at += torch.eye(k)
        at += 1e2 * torch.ones_like(at)  # make gaussian elimination more stable
        at = torch.triu(at).to(tdev)
        bt = torch.eye(m).expand((*batch_shape, -1, -1)).to(tdev)

        ct = torch.linalg.solve_triangular(at, bt, upper=True)

        a = ht.factories.asarray(at, copy=True)
        c = ht.factories.asarray(ct, copy=True)
        b = ht.factories.asarray(bt, copy=True)

        # split in linalg dimension or none
        for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
            a.resplit_(s0)
            b.resplit_(s1)

            res = ht.linalg.solve_triangular(a, b)

            self.assertTrue(ht.allclose(c, res))

        # split in batch dimension
        s = 0
        a.resplit_(s)
        b.resplit_(s)
        c.resplit_(s)

        res = ht.linalg.solve_triangular(a, b)

        self.assertTrue(ht.allclose(c, res))
Review thread (on the solve_triangular docstring):
- This is a bit confusing, and it also does not follow our documentation scheme.
- I have updated the docstring according to our conventions.
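
For reference, a sketch of how the docstring might look once brought in line with the project's documentation scheme, assuming the numpydoc-style sections used elsewhere in Heat; this is an illustration, not the wording that ended up in the PR:

    def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray:
        """
        Solve the upper triangular system of linear equations A @ x = b for x.

        Parameters
        ----------
        A : DNDarray
            Upper triangular, invertible square matrix of shape (..., n, n),
            optionally with leading batch dimensions.
        b : DNDarray
            Right-hand side of shape (..., n, k); a single vector must be
            passed as an (n, 1) matrix.

        Returns
        -------
        DNDarray
            The unique solution x of A @ x = b, with the same shape as b.
        """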