diff --git a/heat/core/linalg/solver.py b/heat/core/linalg/solver.py
index 19f29d710f..bfa99ee914 100644
--- a/heat/core/linalg/solver.py
+++ b/heat/core/linalg/solver.py
@@ -6,10 +6,11 @@
 from ..dndarray import DNDarray
 from ..sanitation import sanitize_out
 from typing import List, Dict, Any, TypeVar, Union, Tuple, Optional
+from .. import factories
 
 import torch
 
-__all__ = ["cg", "lanczos"]
+__all__ = ["cg", "lanczos", "solve_triangular"]
 
 
 def cg(A: DNDarray, b: DNDarray, x0: DNDarray, out: Optional[DNDarray] = None) -> DNDarray:
@@ -269,3 +270,196 @@ def lanczos(
     V.resplit_(axis=None)
 
     return V, T
+
+
+def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray:
+    """
+    Solver for (possibly batched) upper triangular systems of linear equations: returns `x` in `Ax = b`, where `A` is a (possibly batched) upper triangular matrix and
+    `b` a (possibly batched) vector or matrix of suitable shape.
+    The implementation builds on the corresponding solver in PyTorch and provides a memory-distributed, MPI-parallel, block-wise version thereof.
+
+    Parameters
+    ----------
+    A : DNDarray
+        An upper triangular invertible square (n x n) matrix or a batch thereof, i.e. a ``DNDarray`` of shape `(..., n, n)`.
+    b : DNDarray
+        A (possibly batched) n x k matrix, i.e. a ``DNDarray`` of shape `(..., n, k)`, whose batch dimensions (denoted by ...) need to coincide with those of A.
+        (Batched) vectors have to be provided as ... x n x 1 matrices, and the split dimension of b, if not None, must be the second-last dimension.
+
+    Note
+    ----
+    We do not check whether A is indeed upper triangular, since such a check can be computationally expensive.
+    If you require such a check, please open an issue on our GitHub page and request this feature.
+    """
+    if not isinstance(A, DNDarray) or not isinstance(b, DNDarray):
+        raise TypeError(f"Arguments need to be of type DNDarray, got {type(A)}, {type(b)}.")
+    if not A.ndim >= 2:
+        raise ValueError("A needs to be a (batched) matrix.")
+    if not b.ndim == A.ndim:
+        raise ValueError("b needs to have the same number of (batch) dimensions as A.")
+    if not A.shape[-2] == A.shape[-1]:
+        raise ValueError("A needs to be a (batched) square matrix.")
+
+    batch_dim = A.ndim - 2
+    batch_shape = A.shape[:batch_dim]
+
+    if not A.shape[:batch_dim] == b.shape[:batch_dim]:
+        raise ValueError("Batch dimensions of A and b must be of the same shape.")
+    if b.split == batch_dim + 1:
+        raise ValueError("The right-hand side b must not be split along its last dimension.")
+    if not b.shape[batch_dim] == A.shape[-1]:
+        raise ValueError("Dimension mismatch of A and b.")
+
+    if (
+        A.split is not None and A.split < batch_dim or b.split is not None and b.split < batch_dim
+    ):  # batch split
+        if A.split != b.split:
+            raise ValueError(
+                "If a split dimension is a batch dimension, A and b must have the same split dimension. A possible solution would be a resplit of A or b to the same split dimension."
+            )
+    else:
+        if (
+            A.split is not None and b.split is not None
+        ):  # both la dimensions split --> b.split = batch_dim
+            # TODO remove?
+            if not all(A.lshape_map[:, A.split] == b.lshape_map[:, batch_dim]):
+                raise RuntimeError(
+                    "The process-local arrays of A and b have different sizes along the split axis. This is most likely due to one of the DNDarrays being in an unbalanced state. \n Consider using `A.is_balanced(force_check=True)` and `b.is_balanced(force_check=True)` to check if A and b are balanced; \n then call `A.balance_()` and/or `b.balance_()` in order to achieve equal local shapes along the split axis before applying `solve_triangular`."
+                )
+
+    comm = A.comm
+    dev = A.device
+    tdev = dev.torch_device
+
+    nprocs = comm.Get_size()
+
+    if A.split is None:  # A not split
+        if b.split is None:
+            x = torch.linalg.solve_triangular(A.larray, b.larray, upper=True)
+
+            return factories.array(x, dtype=b.dtype, device=dev, comm=comm)
+        else:  # A not split, b.split == -2
+            b_lshapes_cum = torch.hstack(
+                [
+                    torch.zeros(1, dtype=torch.int32, device=tdev),
+                    torch.cumsum(b.lshape_map[:, -2], 0),
+                ]
+            )
+
+            btilde_loc = b.larray.clone()
+            A_loc = A.larray[..., b_lshapes_cum[comm.rank] : b_lshapes_cum[comm.rank + 1]]
+
+            x = factories.zeros_like(b, device=dev, comm=comm)
+
+            for i in range(nprocs - 1, 0, -1):
+                count = x.lshape_map[:, batch_dim].to(torch.device("cpu")).clone().numpy()
+                displ = b_lshapes_cum[:-1].to(torch.device("cpu")).clone().numpy()
+                count[i:] = 0  # nothing to send, as there are only zero rows
+                displ[i:] = 0
+
+                res_send = torch.empty(0)
+                res_recv = torch.zeros((*batch_shape, count[comm.rank], b.shape[-1]), device=tdev)
+
+                if comm.rank == i:
+                    x.larray = torch.linalg.solve_triangular(
+                        A_loc[..., b_lshapes_cum[i] : b_lshapes_cum[i + 1], :],
+                        btilde_loc,
+                        upper=True,
+                    )
+                    res_send = A_loc @ x.larray
+
+                comm.Scatterv((res_send, count, displ), res_recv, root=i, axis=batch_dim)
+
+                if comm.rank < i:
+                    btilde_loc -= res_recv
+
+            if comm.rank == 0:
+                x.larray = torch.linalg.solve_triangular(
+                    A_loc[..., : b_lshapes_cum[1], :], btilde_loc, upper=True
+                )
+
+            return x
+
+    if A.split < batch_dim:  # batch split
+        x = factories.zeros_like(b, device=dev, comm=comm, split=A.split)
+        x.larray = torch.linalg.solve_triangular(A.larray, b.larray, upper=True)
+
+        return x
+
+    if A.split >= batch_dim:  # both splits in la dims
+        A_lshapes_cum = torch.hstack(
+            [
+                torch.zeros(1, dtype=torch.int32, device=tdev),
+                torch.cumsum(A.lshape_map[:, A.split], 0),
+            ]
+        )
+
+        if b.split is None:
+            btilde_loc = b.larray[
+                ..., A_lshapes_cum[comm.rank] : A_lshapes_cum[comm.rank + 1], :
+            ].clone()
+        else:  # b is split at la dim 0
+            btilde_loc = b.larray.clone()
+
+        x = factories.zeros_like(
+            b, device=dev, comm=comm, split=batch_dim
+        )  # split at la dim 0 in case b is not split
+
+        if A.split == batch_dim + 1:
+            for i in range(nprocs - 1, 0, -1):
+                count = x.lshape_map[:, batch_dim].to(torch.device("cpu")).clone().numpy()
+                displ = A_lshapes_cum[:-1].to(torch.device("cpu")).clone().numpy()
+                count[i:] = 0  # nothing to send, as there are only zero rows
+                displ[i:] = 0
+
+                res_send = torch.empty(0)
+                res_recv = torch.zeros((*batch_shape, count[comm.rank], b.shape[-1]), device=tdev)
+
+                if comm.rank == i:
+                    x.larray = torch.linalg.solve_triangular(
+                        A.larray[..., A_lshapes_cum[i] : A_lshapes_cum[i + 1], :],
+                        btilde_loc,
+                        upper=True,
+                    )
+                    res_send = A.larray @ x.larray
+
+                comm.Scatterv((res_send, count, displ), res_recv, root=i, axis=batch_dim)
+
+                if comm.rank < i:
+                    btilde_loc -= res_recv
+
+            if comm.rank == 0:
+                x.larray = torch.linalg.solve_triangular(
+                    A.larray[..., : A_lshapes_cum[1], :], btilde_loc, upper=True
+                )
+
+        else:  # split dim is la dim 0
+            for i in range(nprocs - 1, 0, -1):
+                idims = tuple(x.lshape_map[i])
+                if comm.rank == i:
+                    x.larray = torch.linalg.solve_triangular(
+                        A.larray[..., :, A_lshapes_cum[i] : A_lshapes_cum[i + 1]],
+                        btilde_loc,
+                        upper=True,
+                    )
+                    x_from_i = x.larray
+                else:
+                    x_from_i = torch.zeros(
+                        idims,
+                        dtype=b.dtype.torch_type(),
+                        device=tdev,
+                    )
+
+                comm.Bcast(x_from_i, root=i)
+
+                if comm.rank < i:
+                    btilde_loc -= (
+                        A.larray[..., :, A_lshapes_cum[i] : A_lshapes_cum[i + 1]] @ x_from_i
+                    )
+
+            if comm.rank == 0:
+                x.larray = torch.linalg.solve_triangular(
+                    A.larray[..., :, : A_lshapes_cum[1]], btilde_loc, upper=True
+                )
+
+    return x
diff --git a/heat/core/linalg/tests/test_solver.py b/heat/core/linalg/tests/test_solver.py
index f8f9889a9d..a4c473ac8f 100644
--- a/heat/core/linalg/tests/test_solver.py
+++ b/heat/core/linalg/tests/test_solver.py
@@ -135,3 +135,90 @@ def test_lanczos(self):
         with self.assertRaises(NotImplementedError):
             A = ht.random.randn(10, 10, split=1)
             V, T = ht.lanczos(A, m=3)
+
+    def test_solve_triangular(self):
+        torch.manual_seed(42)
+        tdev = ht.get_device().torch_device
+
+        # non-batched tests
+        k = 100  # data dimension size
+
+        # random triangular matrix inversion
+        at = torch.rand((k, k))
+        # at += torch.eye(k)
+        at += 1e2 * torch.ones_like(at)  # make gaussian elimination more stable
+        at = torch.triu(at).to(tdev)
+
+        ct = torch.linalg.solve_triangular(at, torch.eye(k, device=tdev), upper=True)
+
+        a = ht.factories.asarray(at, copy=True)
+        c = ht.factories.asarray(ct, copy=True)
+        b = ht.eye(k)
+
+        for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
+            a.resplit_(s0)
+            b.resplit_(s1)
+
+            res = ht.linalg.solve_triangular(a, b)
+            self.assertTrue(ht.allclose(res, c))
+
+        # triangular ones inversion
+        # for this test case, the results should be exact
+        at = torch.triu(torch.ones_like(at)).to(tdev)
+        ct = torch.linalg.solve_triangular(at, torch.eye(k, device=tdev), upper=True)
+
+        a = ht.factories.asarray(at, copy=True)
+        c = ht.factories.asarray(ct, copy=True)
+
+        for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
+            a.resplit_(s0)
+            b.resplit_(s1)
+
+            res = ht.linalg.solve_triangular(a, b)
+            self.assertTrue(ht.equal(res, c))
+
+        # batched tests
+        batch_shapes = [
+            (10,),
+            (
+                4,
+                4,
+                4,
+                20,
+            ),
+        ]
+        m = 100  # data dimension size
+
+        for batch_shape in batch_shapes:
+            # batch_shape = tuple()  # no batch dimensions
+
+            at = torch.rand((*batch_shape, m, m))
+            # at += torch.eye(k)
+            at += 1e2 * torch.ones_like(at)  # make gaussian elimination more stable
+            at = torch.triu(at).to(tdev)
+            bt = torch.eye(m).expand((*batch_shape, -1, -1)).to(tdev)
+
+            ct = torch.linalg.solve_triangular(at, bt, upper=True)
+
+            a = ht.factories.asarray(at, copy=True)
+            c = ht.factories.asarray(ct, copy=True)
+            b = ht.factories.asarray(bt, copy=True)
+
+            # split in linalg dimension or none
+            for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
+                a.resplit_(s0)
+                b.resplit_(s1)
+
+                res = ht.linalg.solve_triangular(a, b)
+
+                self.assertTrue(ht.allclose(c, res))
+
+            # split in batch dimension
+            s = len(batch_shape) - 1
+            a.resplit_(s)
+            b.resplit_(s)
+            c.resplit_(s)
+
+            res = ht.linalg.solve_triangular(a, b)
+
+            self.assertTrue(ht.allclose(c, res))
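For illustration, a minimal usage sketch of the new routine (not part of the patch), mirroring the non-batched test above. It assumes heat with this patch applied and a launch over several MPI processes, e.g. `mpirun -n 4 python example.py`; the script name, seed, and matrix size are arbitrary.

import torch
import heat as ht

torch.manual_seed(0)  # same local data on every process, as in the tests

n = 100
# upper triangular matrix, shifted to keep the elimination numerically stable (as in the tests)
at = torch.triu(torch.rand((n, n)) + 1e2 * torch.ones((n, n)))
bt = torch.eye(n)
# process-local reference solution computed directly with PyTorch
ct = torch.linalg.solve_triangular(at, bt, upper=True)

# wrap as DNDarrays and distribute the rows across the MPI processes
A = ht.factories.asarray(at, copy=True)
b = ht.factories.asarray(bt, copy=True)
c = ht.factories.asarray(ct, copy=True)
A.resplit_(-2)
b.resplit_(-2)

x = ht.linalg.solve_triangular(A, b)  # distributed block-wise backward substitution
print(ht.allclose(x, c))  # expected to print True on every process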