Merge pull request #1236 from helmholtz-analytics/features/1096-Provide_a_solver_for_triangular_systems

Add solver for distributed triangular systems
mrfh92 authored Apr 12, 2024
2 parents ef44340 + 6589754 commit b070199
Showing 2 changed files with 282 additions and 1 deletion.
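
For context, a rough usage sketch of the new routine (to be run under MPI, e.g. `mpirun -n 4 python example.py`; the matrix construction is illustrative only and mirrors the tests added below):

import torch
import heat as ht

n = 100
at = torch.triu(torch.rand(n, n) + n * torch.eye(n))   # well-conditioned upper triangular test matrix
A = ht.factories.asarray(at, copy=True)
A.resplit_(-2)                                          # distribute the rows of A across the MPI processes
b = ht.eye(n)
b.resplit_(-2)                                          # right-hand side split along the same axis
x = ht.linalg.solve_triangular(A, b)                    # distributed block-wise back substitution

# compare against the process-local PyTorch solver, as the tests below do
ref = ht.factories.asarray(torch.linalg.solve_triangular(at, torch.eye(n), upper=True))
print(ht.allclose(x, ref))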
196 changes: 195 additions & 1 deletion heat/core/linalg/solver.py
@@ -6,10 +6,11 @@
from ..dndarray import DNDarray
from ..sanitation import sanitize_out
from typing import List, Dict, Any, TypeVar, Union, Tuple, Optional
from .. import factories

import torch

__all__ = ["cg", "lanczos"]
__all__ = ["cg", "lanczos", "solve_triangular"]


def cg(A: DNDarray, b: DNDarray, x0: DNDarray, out: Optional[DNDarray] = None) -> DNDarray:
@@ -269,3 +270,196 @@ def lanczos(
V.resplit_(axis=None)

return V, T


def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray:
"""
    Solver for (possibly batched) upper triangular systems of linear equations: returns `x` such that `Ax = b`, where `A` is a (possibly batched) upper triangular
    matrix and `b` is a (possibly batched) vector or matrix of suitable shape, both provided as input to the function.
    The implementation builds on the corresponding solver in PyTorch and provides a memory-distributed, MPI-parallel, block-wise version thereof.

    Parameters
    ----------
    A : DNDarray
        An upper triangular, invertible, square (n x n) matrix or a batch thereof, i.e. a ``DNDarray`` of shape ``(..., n, n)``.
    b : DNDarray
        A (possibly batched) n x k matrix, i.e. a ``DNDarray`` of shape ``(..., n, k)``, whose batch dimensions (denoted by ``...``) must coincide with those of ``A``.
        (Batched) vectors have to be provided as ``... x n x 1`` matrices, and the split dimension of ``b``, if not None, must be the second-to-last dimension.

    Notes
    -----
    Since such a check might be computationally expensive, we do not verify whether ``A`` is indeed upper triangular.
    If you require such a check, please open an issue on our GitHub page and request this feature.
    """
if not isinstance(A, DNDarray) or not isinstance(b, DNDarray):
raise TypeError(f"Arguments need to be of type DNDarray, got {type(A)}, {type(b)}.")
if not A.ndim >= 2:
raise ValueError("A needs to be a (batched) matrix.")
if not b.ndim == A.ndim:
raise ValueError("b needs to have the same number of (batch) dimensions as A.")
if not A.shape[-2] == A.shape[-1]:
raise ValueError("A needs to be a (batched) square matrix.")

batch_dim = A.ndim - 2
batch_shape = A.shape[:batch_dim]

if not A.shape[:batch_dim] == b.shape[:batch_dim]:
raise ValueError("Batch dimensions of A and b must be of the same shape.")
if b.split == batch_dim + 1:
raise ValueError("split=1 is not allowed for the right hand side.")
if not b.shape[batch_dim] == A.shape[-1]:
raise ValueError("Dimension mismatch of A and b.")

if (
A.split is not None and A.split < batch_dim or b.split is not None and b.split < batch_dim
): # batch split
if A.split != b.split:
raise ValueError(
"If a split dimension is a batch dimension, A and b must have the same split dimension. A possible solution would be a resplit of A or b to the same split dimension."
)
else:
if (
A.split is not None and b.split is not None
): # both la dimensions split --> b.split = batch_dim
# TODO remove?
if not all(A.lshape_map[:, A.split] == b.lshape_map[:, batch_dim]):
raise RuntimeError(
"The process-local arrays of A and b have different sizes along the splitted axis. This is most likely due to one of the DNDarrays being in unbalanced state. \n Consider using `A.is_balanced(force_check=True)` and `b.is_balanced(force_check=True)` to check if A and b are balanced; \n then call `A.balance_()` and/or `b.balance_()` in order to achieve equal local shapes along the split axis before applying `solve_triangular`."
)

comm = A.comm
dev = A.device
tdev = dev.torch_device

nprocs = comm.Get_size()

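    # Dispatch on the split configuration:
    #   1. A not split: solve locally (b not split) or scatter partial products (b split along its rows).
    #   2. A split along a batch dimension: each process solves its own batch entries independently.
    #   3. A split along one of the last two dimensions: block-wise back substitution across processes.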
if A.split is None: # A not split
if b.split is None:
x = torch.linalg.solve_triangular(A.larray, b.larray, upper=True)

return factories.array(x, dtype=b.dtype, device=dev, comm=comm)
else: # A not split, b.split == -2
b_lshapes_cum = torch.hstack(
[
torch.zeros(1, dtype=torch.int32, device=tdev),
torch.cumsum(b.lshape_map[:, -2], 0),
]
)

btilde_loc = b.larray.clone()
A_loc = A.larray[..., b_lshapes_cum[comm.rank] : b_lshapes_cum[comm.rank + 1]]

x = factories.zeros_like(b, device=dev, comm=comm)

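        # Backward substitution over process blocks: process i solves for its block of x, then the
        # products A[:, block_i] @ x_i are scattered so that all lower-ranked processes can subtract
        # them from their local right-hand sides.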
for i in range(nprocs - 1, 0, -1):
count = x.lshape_map[:, batch_dim].to(torch.device("cpu")).clone().numpy()
displ = b_lshapes_cum[:-1].to(torch.device("cpu")).clone().numpy()
count[i:] = 0 # nothing to send, as there are only zero rows
displ[i:] = 0

res_send = torch.empty(0)
res_recv = torch.zeros((*batch_shape, count[comm.rank], b.shape[-1]), device=tdev)

if comm.rank == i:
x.larray = torch.linalg.solve_triangular(
A_loc[..., b_lshapes_cum[i] : b_lshapes_cum[i + 1], :],
btilde_loc,
upper=True,
)
res_send = A_loc @ x.larray

comm.Scatterv((res_send, count, displ), res_recv, root=i, axis=batch_dim)

if comm.rank < i:
btilde_loc -= res_recv

if comm.rank == 0:
x.larray = torch.linalg.solve_triangular(
A_loc[..., : b_lshapes_cum[1], :], btilde_loc, upper=True
)

return x

if A.split < batch_dim: # batch split
x = factories.zeros_like(b, device=dev, comm=comm, split=A.split)
x.larray = torch.linalg.solve_triangular(A.larray, b.larray, upper=True)

return x

if A.split >= batch_dim: # both splits in la dims
A_lshapes_cum = torch.hstack(
[
torch.zeros(1, dtype=torch.int32, device=tdev),
torch.cumsum(A.lshape_map[:, A.split], 0),
]
)

if b.split is None:
btilde_loc = b.larray[
..., A_lshapes_cum[comm.rank] : A_lshapes_cum[comm.rank + 1], :
].clone()
else: # b is split at la dim 0
btilde_loc = b.larray.clone()

x = factories.zeros_like(
b, device=dev, comm=comm, split=batch_dim
) # split at la dim 0 in case b is not split

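        # Two sub-cases: if A is split along its columns (last dimension), the scatter-based update
        # from above is reused; if A is split along its rows (second-to-last dimension), each solved
        # block of x is broadcast instead, and lower-ranked processes update their local right-hand sides.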
if A.split == batch_dim + 1:
for i in range(nprocs - 1, 0, -1):
count = x.lshape_map[:, batch_dim].to(torch.device("cpu")).clone().numpy()
displ = A_lshapes_cum[:-1].to(torch.device("cpu")).clone().numpy()
count[i:] = 0 # nothing to send, as there are only zero rows
displ[i:] = 0

res_send = torch.empty(0)
res_recv = torch.zeros((*batch_shape, count[comm.rank], b.shape[-1]), device=tdev)

if comm.rank == i:
x.larray = torch.linalg.solve_triangular(
A.larray[..., A_lshapes_cum[i] : A_lshapes_cum[i + 1], :],
btilde_loc,
upper=True,
)
res_send = A.larray @ x.larray

comm.Scatterv((res_send, count, displ), res_recv, root=i, axis=batch_dim)

if comm.rank < i:
btilde_loc -= res_recv

if comm.rank == 0:
x.larray = torch.linalg.solve_triangular(
A.larray[..., : A_lshapes_cum[1], :], btilde_loc, upper=True
)

else: # split dim is la dim 0
for i in range(nprocs - 1, 0, -1):
idims = tuple(x.lshape_map[i])
if comm.rank == i:
x.larray = torch.linalg.solve_triangular(
A.larray[..., :, A_lshapes_cum[i] : A_lshapes_cum[i + 1]],
btilde_loc,
upper=True,
)
x_from_i = x.larray
else:
x_from_i = torch.zeros(
idims,
dtype=b.dtype.torch_type(),
device=tdev,
)

comm.Bcast(x_from_i, root=i)

if comm.rank < i:
btilde_loc -= (
A.larray[..., :, A_lshapes_cum[i] : A_lshapes_cum[i + 1]] @ x_from_i
)

if comm.rank == 0:
x.larray = torch.linalg.solve_triangular(
A.larray[..., :, : A_lshapes_cum[1]], btilde_loc, upper=True
)

return x
87 changes: 87 additions & 0 deletions heat/core/linalg/tests/test_solver.py
@@ -135,3 +135,90 @@ def test_lanczos(self):
with self.assertRaises(NotImplementedError):
A = ht.random.randn(10, 10, split=1)
V, T = ht.lanczos(A, m=3)

def test_solve_triangular(self):
torch.manual_seed(42)
tdev = ht.get_device().torch_device

# non-batched tests
k = 100 # data dimension size

# random triangular matrix inversion
at = torch.rand((k, k))
# at += torch.eye(k)
        at += 1e2 * torch.ones_like(at)  # avoid small diagonal entries so that the triangular solve stays numerically stable
at = torch.triu(at).to(tdev)

ct = torch.linalg.solve_triangular(at, torch.eye(k, device=tdev), upper=True)

a = ht.factories.asarray(at, copy=True)
c = ht.factories.asarray(ct, copy=True)
b = ht.eye(k)

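        # exercise all supported split combinations of A and b:
        # no split, split along the rows or columns of A, and split along the rows of b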
for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
a.resplit_(s0)
b.resplit_(s1)

res = ht.linalg.solve_triangular(a, b)
self.assertTrue(ht.allclose(res, c))

# triangular ones inversion
# for this test case, the results should be exact
at = torch.triu(torch.ones_like(at)).to(tdev)
ct = torch.linalg.solve_triangular(at, torch.eye(k, device=tdev), upper=True)

a = ht.factories.asarray(at, copy=True)
c = ht.factories.asarray(ct, copy=True)

for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
a.resplit_(s0)
b.resplit_(s1)

res = ht.linalg.solve_triangular(a, b)
self.assertTrue(ht.equal(res, c))

# batched tests
batch_shapes = [
(10,),
(
4,
4,
4,
20,
),
]
m = 100 # data dimension size

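        # repeat the same checks for batched inputs, including a split along a batch dimension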
for batch_shape in batch_shapes:
# batch_shape = tuple() # no batch dimensions

at = torch.rand((*batch_shape, m, m))
# at += torch.eye(k)
            at += 1e2 * torch.ones_like(at)  # avoid small diagonal entries so that the triangular solve stays numerically stable
at = torch.triu(at).to(tdev)
bt = torch.eye(m).expand((*batch_shape, -1, -1)).to(tdev)

ct = torch.linalg.solve_triangular(at, bt, upper=True)

a = ht.factories.asarray(at, copy=True)
c = ht.factories.asarray(ct, copy=True)
b = ht.factories.asarray(bt, copy=True)

# split in linalg dimension or none
for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
a.resplit_(s0)
b.resplit_(s1)

res = ht.linalg.solve_triangular(a, b)

self.assertTrue(ht.allclose(c, res))

# split in batch dimension
s = len(batch_shape) - 1
a.resplit_(s)
b.resplit_(s)
c.resplit_(s)

res = ht.linalg.solve_triangular(a, b)

self.assertTrue(ht.allclose(c, res))
