Add solver for triangular systems #1504

Merged
193 changes: 192 additions & 1 deletion heat/core/linalg/solver.py
@@ -10,7 +10,7 @@

import torch

__all__ = ["cg", "lanczos"]
__all__ = ["cg", "lanczos", "solve_triangular"]


def cg(A: DNDarray, b: DNDarray, x0: DNDarray, out: Optional[DNDarray] = None) -> DNDarray:
@@ -270,3 +270,194 @@ def lanczos(
V.resplit_(axis=None)

return V, T


def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray:
"""
This function provides a solver for (possibly batched) upper triangular systems of linear equations: it returns `x` in `Ax = b`, where `A` is a (possibly batched) upper triangular matrix and
`b` a (possibly batched) vector or matrix of suitable shape, both provided as input to the function.
The implementation builts on the corresponding solver in PyTorch and implements an memory-distributed, MPI-parallel block-wise version thereof.
Parameters
----------
A : DNDarray
An upper triangular invertible square (n x n) matrix or a batch thereof, i.e. a ``DNDarray`` of shape `(..., n, n)`.
b : DNDarray
a (possibly batched) n x k matrix, i.e. an DNDarray of shape (..., n, k), where the batch-dimensions denoted by ... need to coincide with those of A.
(Batched) Vectors have to be provided as ... x n x 1 matrices and the split dimension of b must the second last dimension if not None.
Note
---------
Since such a check might be computationally expensive, we do not check whether A is indeed upper triangular.
If you require such a check, please open an issue on our GitHub page and request this feature.
"""
if not isinstance(A, DNDarray) or not isinstance(b, DNDarray):
raise TypeError(f"Arguments need to be of type DNDarray, got {type(A)}, {type(b)}.")
if not A.ndim >= 2:
raise ValueError("A needs to be a (batched) matrix.")
if not b.ndim == A.ndim:
raise ValueError("b needs to have the same number of (batch) dimensions as A.")
if not A.shape[-2] == A.shape[-1]:
raise ValueError("A needs to be a (batched) square matrix.")

batch_dim = A.ndim - 2
batch_shape = A.shape[:batch_dim]

if not A.shape[:batch_dim] == b.shape[:batch_dim]:
raise ValueError("Batch dimensions of A and b must be of the same shape.")
if b.split == batch_dim + 1:
raise ValueError("A split along the last dimension is not allowed for the right-hand side.")
if not b.shape[batch_dim] == A.shape[-1]:
raise ValueError("Dimension mismatch of A and b.")

if (
A.split is not None and A.split < batch_dim or b.split is not None and b.split < batch_dim
): # batch split
if A.split != b.split:
raise ValueError(
"If a split dimension is a batch dimension, A and b must have the same split dimension. A possible solution would be a resplit of A or b to the same split dimension."
)
else:
if (
A.split is not None and b.split is not None
): # both la dimensions split --> b.split = batch_dim
# TODO remove?
if not all(A.lshape_map[:, A.split] == b.lshape_map[:, batch_dim]):
raise RuntimeError(
"The process-local arrays of A and b have different sizes along the split axis. This is most likely due to one of the DNDarrays being in an unbalanced state. \n Consider using `A.is_balanced(force_check=True)` and `b.is_balanced(force_check=True)` to check whether A and b are balanced; \n then call `A.balance_()` and/or `b.balance_()` to achieve equal local shapes along the split axis before applying `solve_triangular`."
)

comm = A.comm
dev = A.device
tdev = dev.torch_device

nprocs = comm.Get_size()

if A.split is None: # A not split
if b.split is None:
x = torch.linalg.solve_triangular(A.larray, b.larray, upper=True)

return factories.array(x, dtype=b.dtype, device=dev, comm=comm)
else: # A not split, b.split == -2
b_lshapes_cum = torch.hstack(
[
torch.zeros(1, dtype=torch.int32, device=tdev),
torch.cumsum(b.lshape_map[:, -2], 0),
]
)

btilde_loc = b.larray.clone()
A_loc = A.larray[..., b_lshapes_cum[comm.rank] : b_lshapes_cum[comm.rank + 1]]

x = factories.zeros_like(b, device=dev, comm=comm)

for i in range(nprocs - 1, 0, -1):
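# in step i, process i solves its diagonal block of the system; the contribution A_loc @ x_i is
# scattered to the lower-ranked processes, which subtract it from their local right-hand sides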
count = x.lshape_map[:, batch_dim].to(torch.device("cpu")).clone().numpy()
displ = b_lshapes_cum[:-1].to(torch.device("cpu")).clone().numpy()
count[i:] = 0 # nothing to send, as there are only zero rows
displ[i:] = 0

res_send = torch.empty(0)
res_recv = torch.zeros((*batch_shape, count[comm.rank], b.shape[-1]), device=tdev)

if comm.rank == i:
x.larray = torch.linalg.solve_triangular(
A_loc[..., b_lshapes_cum[i] : b_lshapes_cum[i + 1], :],
btilde_loc,
upper=True,
)
res_send = A_loc @ x.larray

comm.Scatterv((res_send, count, displ), res_recv, root=i, axis=batch_dim)

if comm.rank < i:
btilde_loc -= res_recv

if comm.rank == 0:
x.larray = torch.linalg.solve_triangular(
A_loc[..., : b_lshapes_cum[1], :], btilde_loc, upper=True
)

return x

if A.split < batch_dim: # batch split
x = factories.zeros_like(b, device=dev, comm=comm, split=A.split)
x.larray = torch.linalg.solve_triangular(A.larray, b.larray, upper=True)

return x

if A.split >= batch_dim: # both splits in la dims
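# here, b is either not split or split along its rows (la dim 0); splits of b along the last dimension
# or along a batch dimension that differs from A's split have already been ruled out above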
A_lshapes_cum = torch.hstack(
[
torch.zeros(1, dtype=torch.int32, device=tdev),
torch.cumsum(A.lshape_map[:, A.split], 0),
]
)

if b.split is None:
btilde_loc = b.larray[
..., A_lshapes_cum[comm.rank] : A_lshapes_cum[comm.rank + 1], :
].clone()
else: # b is split at la dim 0
btilde_loc = b.larray.clone()

x = factories.zeros_like(
b, device=dev, comm=comm, split=batch_dim
) # split at la dim 0 in case b is not split

if A.split == batch_dim + 1:
for i in range(nprocs - 1, 0, -1):
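# A is split along its columns: in step i, process i solves its diagonal block for its part of x;
# the column contribution A.larray @ x_i is scattered to the lower-ranked processes, which subtract
# it from their local right-hand sides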
count = x.lshape_map[:, batch_dim].to(torch.device("cpu")).clone().numpy()
displ = A_lshapes_cum[:-1].to(torch.device("cpu")).clone().numpy()
count[i:] = 0 # nothing to send, as there are only zero rows
displ[i:] = 0

res_send = torch.empty(0)
res_recv = torch.zeros((*batch_shape, count[comm.rank], b.shape[-1]), device=tdev)

if comm.rank == i:
x.larray = torch.linalg.solve_triangular(
A.larray[..., A_lshapes_cum[i] : A_lshapes_cum[i + 1], :],
btilde_loc,
upper=True,
)
res_send = A.larray @ x.larray

comm.Scatterv((res_send, count, displ), res_recv, root=i, axis=batch_dim)

if comm.rank < i:
btilde_loc -= res_recv

if comm.rank == 0:
x.larray = torch.linalg.solve_triangular(
A.larray[..., : A_lshapes_cum[1], :], btilde_loc, upper=True
)

else: # split dim is la dim 0
for i in range(nprocs - 1, 0, -1):
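# A is split along its rows: in step i, process i solves its diagonal block, its part of x is broadcast
# to all processes, and each lower-ranked process subtracts the corresponding column block of A times x_i
# from its local right-hand side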
idims = tuple(x.lshape_map[i])
if comm.rank == i:
x.larray = torch.linalg.solve_triangular(
A.larray[..., :, A_lshapes_cum[i] : A_lshapes_cum[i + 1]],
btilde_loc,
upper=True,
)
x_from_i = x.larray
else:
x_from_i = torch.zeros(
idims,
dtype=b.dtype.torch_type(),
device=tdev,
)

comm.Bcast(x_from_i, root=i)

if comm.rank < i:
btilde_loc -= (
A.larray[..., :, A_lshapes_cum[i] : A_lshapes_cum[i + 1]] @ x_from_i
)

if comm.rank == 0:
x.larray = torch.linalg.solve_triangular(
A.larray[..., :, : A_lshapes_cum[1]], btilde_loc, upper=True
)

return x
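
A minimal usage sketch for orientation (the sizes and the `split=0` choice are illustrative assumptions; it assumes the usual Heat factories and an MPI launch):

import torch
import heat as ht

torch.manual_seed(0)  # identical seed on every process so all ranks construct the same data
n = 1000
at = torch.triu(torch.rand(n, n) + 100.0)  # well-conditioned upper triangular test matrix
A = ht.array(at, split=0)                  # rows of A distributed across the MPI processes
b = ht.ones((n, 1), split=0)               # right-hand side split along the same dimension
x = ht.linalg.solve_triangular(A, b)       # distributed block-wise back-substitution

Launched with, e.g., `mpirun -n 4 python script.py`; with `split=None` on both arguments the call reduces to a single local `torch.linalg.solve_triangular`.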
131 changes: 131 additions & 0 deletions heat/core/linalg/tests/test_solver.py
@@ -135,3 +135,134 @@ def test_lanczos(self):
with self.assertRaises(NotImplementedError):
A = ht.random.randn(10, 10, split=1)
V, T = ht.lanczos(A, m=3)

def test_solve_triangular(self):
torch.manual_seed(42)
tdev = ht.get_device().torch_device

# non-batched tests
k = 100 # data dimension size

# random triangular matrix inversion
at = torch.rand((k, k))
# at += torch.eye(k)
at += 1e2 * torch.ones_like(at) # make gaussian elimination more stable
at = torch.triu(at).to(tdev)

ct = torch.linalg.solve_triangular(at, torch.eye(k, device=tdev), upper=True)

a = ht.factories.asarray(at, copy=True)
c = ht.factories.asarray(ct, copy=True)
b = ht.eye(k)

# exceptions
with self.assertRaises(TypeError): # invalid datatype for b
ht.linalg.solve_triangular(a, 42)

with self.assertRaises(ValueError): # a no matrix, not enough dimensions
ht.linalg.solve_triangular(a[1], b)

with self.assertRaises(ValueError): # a and b different number of dimensions
ht.linalg.solve_triangular(a, b[1])

with self.assertRaises(ValueError): # a no square matrix
ht.linalg.solve_triangular(a[1:, ...], b)

with self.assertRaises(ValueError): # split=1 for b
b.resplit_(-1)
ht.linalg.solve_triangular(a, b)

b.resplit_(0)
with self.assertRaises(ValueError): # dimension mismatch
ht.linalg.solve_triangular(a, b[1:, ...])

for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
a.resplit_(s0)
b.resplit_(s1)

res = ht.linalg.solve_triangular(a, b)
self.assertTrue(ht.allclose(res, c))

# triangular ones inversion
# for this test case, the results should be exact
at = torch.triu(torch.ones_like(at)).to(tdev)
ct = torch.linalg.solve_triangular(at, torch.eye(k, device=tdev), upper=True)

a = ht.factories.asarray(at, copy=True)
c = ht.factories.asarray(ct, copy=True)

for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
a.resplit_(s0)
b.resplit_(s1)

res = ht.linalg.solve_triangular(a, b)
self.assertTrue(ht.equal(res, c))

# batched tests
batch_shapes = [(10,), (4, 4, 4, 20)]
m = 100 # data dimension size

# exceptions
batch_shape = batch_shapes[1]

at = torch.rand((*batch_shape, m, m))
# at += torch.eye(k)
at += 1e2 * torch.ones_like(at) # make gaussian elimination more stable
at = torch.triu(at).to(tdev)
bt = torch.eye(m).expand((*batch_shape, -1, -1)).to(tdev)

ct = torch.linalg.solve_triangular(at, bt, upper=True)

a = ht.factories.asarray(at, copy=True)
c = ht.factories.asarray(ct, copy=True)
b = ht.factories.asarray(bt, copy=True)

with self.assertRaises(ValueError): # batch dimensions of different shapes
ht.linalg.solve_triangular(a[1:, ...], b)

with self.assertRaises(ValueError): # different batched split dimensions
a.resplit_(0)
b.resplit_(1)
ht.linalg.solve_triangular(a, b)

for batch_shape in batch_shapes:
# batch_shape = tuple() # no batch dimensions

at = torch.rand((*batch_shape, m, m))
# at += torch.eye(k)
at += 1e2 * torch.ones_like(at) # make gaussian elimination more stable
at = torch.triu(at).to(tdev)
bt = torch.eye(m).expand((*batch_shape, -1, -1)).to(tdev)

ct = torch.linalg.solve_triangular(at, bt, upper=True)

a = ht.factories.asarray(at, copy=True)
c = ht.factories.asarray(ct, copy=True)
b = ht.factories.asarray(bt, copy=True)

# split in linalg dimension or none
for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
a.resplit_(s0)
b.resplit_(s1)

res = ht.linalg.solve_triangular(a, b)

self.assertTrue(ht.allclose(c, res))

# split in batch dimension
s = len(batch_shape) - 1
a.resplit_(s)
b.resplit_(s)
c.resplit_(s)

res = ht.linalg.solve_triangular(a, b)

self.assertTrue(ht.allclose(c, res))
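
A corresponding batched sketch, mirroring the tests above (sizes are illustrative; the data is built with the same seed on every process before being distributed along the batch dimension):

import torch
import heat as ht

torch.manual_seed(0)
at = torch.triu(torch.rand(8, 50, 50) + 100.0)  # 8 independent upper triangular systems
A = ht.array(at, split=0)                       # split along the batch dimension
b = ht.ones((8, 50, 1), split=0)
x = ht.linalg.solve_triangular(A, b)            # each process solves its local batch entries via torch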