Add solver for distributed triangular systems #1236

Merged · 36 commits · Apr 12, 2024

Commits (36)

8ffa51c
added my first ideas for an implementation of solve_triangular (not a…
Oct 10, 2023
8003642
readme test
FOsterfeld Oct 10, 2023
e6c88e8
dummy commit to find out whether CI works...
mrfh92 Oct 12, 2023
782c7b2
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
mrfh92 Dec 18, 2023
524e294
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
mrfh92 Jan 3, 2024
3749dbf
small improvements in solve_triangular for the case A.split == 1
FOsterfeld Jan 16, 2024
de2fe73
small improvements in solve_triangular for the case A.split == 0
FOsterfeld Jan 17, 2024
85fd3af
added tests for solve_triangular
FOsterfeld Jan 17, 2024
bcac92e
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
mrfh92 Jan 22, 2024
aa27c2c
implemented the case A.split == 1 with scatterv instead of bcast to r…
FOsterfeld Feb 6, 2024
6fe7ed0
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
mrfh92 Feb 6, 2024
7accc48
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
mrfh92 Feb 20, 2024
7f2ccc5
implemented solve_triangular for batched inputs and b.split = None
FOsterfeld Feb 20, 2024
7aa3f8d
Merge branch 'features/1096-Provide_a_solver_for_triangular_systems' …
FOsterfeld Feb 20, 2024
8d2af3e
implemented solve_triangular for A.split = None, b.split = -2
FOsterfeld Feb 27, 2024
fd7e2d7
added tests for batched input of solve_triangular
FOsterfeld Feb 27, 2024
c8b0890
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
mrfh92 Feb 27, 2024
4591260
fixed a bug when not using the cpu
FOsterfeld Feb 29, 2024
c62434e
Merge branch 'features/1096-Provide_a_solver_for_triangular_systems' …
FOsterfeld Feb 29, 2024
25bbabd
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
FOsterfeld Feb 29, 2024
1bcd97f
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
FOsterfeld Feb 29, 2024
cf84b88
specify device
FOsterfeld Feb 29, 2024
18d3f00
fixed on gpu
FOsterfeld Mar 5, 2024
4f3cd08
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
FOsterfeld Mar 5, 2024
bac8fc7
improved exception handling for invalid input data
FOsterfeld Mar 26, 2024
22507ae
added tests for higher dimensional batches
FOsterfeld Mar 26, 2024
96fa81f
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
mrfh92 Apr 3, 2024
bfac0c3
changed error message for non-equal local sizes along split axis
Apr 3, 2024
db370fe
updated docstring according to our conventions
Apr 3, 2024
f790e23
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
ClaudiaComito Apr 11, 2024
af10937
Update heat/core/linalg/solver.py
mrfh92 Apr 11, 2024
795053f
Update heat/core/linalg/solver.py
mrfh92 Apr 11, 2024
ba3578a
Update heat/core/linalg/solver.py
mrfh92 Apr 11, 2024
e248a42
Update solver.py
mrfh92 Apr 11, 2024
d9edee7
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
mrfh92 Apr 11, 2024
6589754
Merge branch 'main' into features/1096-Provide_a_solver_for_triangula…
mrfh92 Apr 12, 2024
196 changes: 195 additions & 1 deletion heat/core/linalg/solver.py
@@ -6,10 +6,11 @@
from ..dndarray import DNDarray
from ..sanitation import sanitize_out
from typing import List, Dict, Any, TypeVar, Union, Tuple, Optional
from .. import factories

import torch

__all__ = ["cg", "lanczos"]
__all__ = ["cg", "lanczos", "solve_triangular"]


def cg(A: DNDarray, b: DNDarray, x0: DNDarray, out: Optional[DNDarray] = None) -> DNDarray:
@@ -269,3 +270,196 @@ def lanczos(
V.resplit_(axis=None)

return V, T


def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray:
"""
This function provides a solver for (possibly batched) upper triangular systems of linear equations: it returns x such that Ax = b, where A is a (possibly batched) upper triangular matrix and
b is a (possibly batched) vector or matrix of suitable shape, both provided as input to the function.
The implementation builds on the corresponding solver in PyTorch and implements a block-wise version thereof.

Parameters
----------
A : DNDarray
An upper triangular (possibly batched) invertible square (n x n) matrix, i.e. a DNDarray of shape (..., n, n).
b : DNDarray
A (possibly batched) n x k matrix, i.e. a DNDarray of shape (..., n, k), where the batch dimensions denoted by ... need to coincide with those of A.
Vectors have to be provided as n x 1 matrices, and the split dimension of b must be the second-to-last dimension if it is not None.

Note
----
Since such a check might be computationally expensive, we do not check whether A is indeed upper triangular.
If you require such a check, please open an issue on our GitHub page and request this feature.
"""
if not isinstance(A, DNDarray) or not isinstance(b, DNDarray):
raise TypeError(f"Arguments need to be of type DNDarray, got {type(A)}, {type(b)}.")
if not A.ndim >= 2:
raise ValueError("A needs to be a (batched) matrix.")
if not b.ndim == A.ndim:
raise ValueError("b needs to have the same number of (batch) dimensions as A.")
if not A.shape[-2] == A.shape[-1]:
raise ValueError("A needs to be a (batched) square matrix.")

batch_dim = A.ndim - 2
batch_shape = A.shape[:batch_dim]

if not A.shape[:batch_dim] == b.shape[:batch_dim]:
raise ValueError("Batch dimensions of A and b must be of the same shape.")
if b.split == batch_dim + 1:
raise ValueError("split=1 is not allowed for the right hand side.")
if not b.shape[batch_dim] == A.shape[-1]:
raise ValueError("Dimension mismatch of A and b.")

if (
A.split is not None and A.split < batch_dim or b.split is not None and b.split < batch_dim
): # batch split
if A.split != b.split:
raise ValueError(
"If a split dimension is a batch dimension, A and b must have the same split dimension. A possible solution would be a resplit of A or b to the same split dimension."
)
else:
if (
A.split is not None and b.split is not None
): # both la dimensions split --> b.split = batch_dim
# TODO remove?
if not all(A.lshape_map[:, A.split] == b.lshape_map[:, batch_dim]):
raise RuntimeError(
"The process-local arrays of A and b have different sizes along the splitted axis. This is most likely due to one of the DNDarrays being in unbalanced state. \n Consider using `A.is_balanced(force_check=True)` and `b.is_balanced(force_check=True)` to check if A and b are balanced; \n then call `A.balance_()` and/or `b.balance_()` in order to achieve equal local shapes along the split axis before applying `solve_triangular`."
)

comm = A.comm
dev = A.device
tdev = dev.torch_device

nprocs = comm.Get_size()

if A.split is None: # A not split
if b.split is None:
x = torch.linalg.solve_triangular(A.larray, b.larray, upper=True)

return factories.array(x, dtype=b.dtype, device=dev, comm=comm)
else: # A not split, b.split == -2
b_lshapes_cum = torch.hstack(
[
torch.zeros(1, dtype=torch.int32, device=tdev),
torch.cumsum(b.lshape_map[:, -2], 0),
]
)

btilde_loc = b.larray.clone()
A_loc = A.larray[..., b_lshapes_cum[comm.rank] : b_lshapes_cum[comm.rank + 1]]

x = factories.zeros_like(b, device=dev, comm=comm)

for i in range(nprocs - 1, 0, -1):
count = x.lshape_map[:, batch_dim].to(torch.device("cpu")).clone().numpy()
displ = b_lshapes_cum[:-1].to(torch.device("cpu")).clone().numpy()
count[i:] = 0 # nothing to send, as there are only zero rows
displ[i:] = 0

res_send = torch.empty(0)
res_recv = torch.zeros((*batch_shape, count[comm.rank], b.shape[-1]), device=tdev)

if comm.rank == i:
x.larray = torch.linalg.solve_triangular(
A_loc[..., b_lshapes_cum[i] : b_lshapes_cum[i + 1], :],
btilde_loc,
upper=True,
)
res_send = A_loc @ x.larray

comm.Scatterv((res_send, count, displ), res_recv, root=i, axis=batch_dim)

if comm.rank < i:
btilde_loc -= res_recv

if comm.rank == 0:
x.larray = torch.linalg.solve_triangular(
A_loc[..., : b_lshapes_cum[1], :], btilde_loc, upper=True
)

return x

if A.split < batch_dim: # batch split
x = factories.zeros_like(b, device=dev, comm=comm, split=A.split)
x.larray = torch.linalg.solve_triangular(A.larray, b.larray, upper=True)

return x

if A.split >= batch_dim: # both splits in la dims
A_lshapes_cum = torch.hstack(
[
torch.zeros(1, dtype=torch.int32, device=tdev),
torch.cumsum(A.lshape_map[:, A.split], 0),
]
)

if b.split is None:
btilde_loc = b.larray[
..., A_lshapes_cum[comm.rank] : A_lshapes_cum[comm.rank + 1], :
].clone()
else: # b is split at la dim 0
btilde_loc = b.larray.clone()

x = factories.zeros_like(
b, device=dev, comm=comm, split=batch_dim
) # split at la dim 0 in case b is not split

if A.split == batch_dim + 1:
for i in range(nprocs - 1, 0, -1):
count = x.lshape_map[:, batch_dim].to(torch.device("cpu")).clone().numpy()
displ = A_lshapes_cum[:-1].to(torch.device("cpu")).clone().numpy()
count[i:] = 0 # nothing to send, as there are only zero rows
displ[i:] = 0

res_send = torch.empty(0)
res_recv = torch.zeros((*batch_shape, count[comm.rank], b.shape[-1]), device=tdev)

if comm.rank == i:
x.larray = torch.linalg.solve_triangular(
A.larray[..., A_lshapes_cum[i] : A_lshapes_cum[i + 1], :],
btilde_loc,
upper=True,
)
res_send = A.larray @ x.larray

comm.Scatterv((res_send, count, displ), res_recv, root=i, axis=batch_dim)

if comm.rank < i:
btilde_loc -= res_recv

if comm.rank == 0:
x.larray = torch.linalg.solve_triangular(
A.larray[..., : A_lshapes_cum[1], :], btilde_loc, upper=True
)

else: # split dim is la dim 0
for i in range(nprocs - 1, 0, -1):
idims = tuple(x.lshape_map[i])
if comm.rank == i:
x.larray = torch.linalg.solve_triangular(
A.larray[..., :, A_lshapes_cum[i] : A_lshapes_cum[i + 1]],
btilde_loc,
upper=True,
)
x_from_i = x.larray
else:
x_from_i = torch.zeros(
idims,
dtype=b.dtype.torch_type(),
device=tdev,
)

comm.Bcast(x_from_i, root=i)

if comm.rank < i:
btilde_loc -= (
A.larray[..., :, A_lshapes_cum[i] : A_lshapes_cum[i + 1]] @ x_from_i
)

if comm.rank == 0:
x.larray = torch.linalg.solve_triangular(
A.larray[..., :, : A_lshapes_cum[1]], btilde_loc, upper=True
)

return x
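
For orientation, here is a minimal usage sketch of the new solver (not part of the diff; it mirrors the calls used in the test file below and assumes the package is imported as ht and the script is launched under MPI so that resplit_ actually distributes the data):

import torch
import heat as ht

k = 100
# Well-conditioned upper triangular system, built as in the tests below.
at = torch.triu(torch.rand((k, k)) + 1e2 * torch.ones((k, k)))
ct = torch.linalg.solve_triangular(at, torch.eye(k), upper=True)

A = ht.factories.asarray(at, copy=True)
c = ht.factories.asarray(ct, copy=True)
b = ht.eye(k)

# Distribute A and b along the row dimension; solve_triangular then runs
# the block-wise backward substitution across the participating processes.
A.resplit_(-2)
b.resplit_(-2)

x = ht.linalg.solve_triangular(A, b)  # x approximates the inverse of A
print(ht.allclose(x, c))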
87 changes: 87 additions & 0 deletions heat/core/linalg/tests/test_solver.py
@@ -135,3 +135,90 @@ def test_lanczos(self):
with self.assertRaises(NotImplementedError):
A = ht.random.randn(10, 10, split=1)
V, T = ht.lanczos(A, m=3)

def test_solve_triangular(self):
torch.manual_seed(42)
tdev = ht.get_device().torch_device

# non-batched tests
k = 100 # data dimension size

# random triangular matrix inversion
at = torch.rand((k, k))
# at += torch.eye(k)
at += 1e2 * torch.ones_like(at) # make Gaussian elimination more stable
at = torch.triu(at).to(tdev)

ct = torch.linalg.solve_triangular(at, torch.eye(k, device=tdev), upper=True)

a = ht.factories.asarray(at, copy=True)
c = ht.factories.asarray(ct, copy=True)
b = ht.eye(k)

for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
a.resplit_(s0)
b.resplit_(s1)

res = ht.linalg.solve_triangular(a, b)
self.assertTrue(ht.allclose(res, c))

# triangular ones inversion
# for this test case, the results should be exact
at = torch.triu(torch.ones_like(at)).to(tdev)
ct = torch.linalg.solve_triangular(at, torch.eye(k, device=tdev), upper=True)

a = ht.factories.asarray(at, copy=True)
c = ht.factories.asarray(ct, copy=True)

for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
a.resplit_(s0)
b.resplit_(s1)

res = ht.linalg.solve_triangular(a, b)
self.assertTrue(ht.equal(res, c))

# batched tests
batch_shapes = [
(10,),
(
4,
4,
4,
20,
),
]
m = 100 # data dimension size

for batch_shape in batch_shapes:
# batch_shape = tuple() # no batch dimensions

at = torch.rand((*batch_shape, m, m))
# at += torch.eye(k)
at += 1e2 * torch.ones_like(at) # make Gaussian elimination more stable
at = torch.triu(at).to(tdev)
bt = torch.eye(m).expand((*batch_shape, -1, -1)).to(tdev)

ct = torch.linalg.solve_triangular(at, bt, upper=True)

a = ht.factories.asarray(at, copy=True)
c = ht.factories.asarray(ct, copy=True)
b = ht.factories.asarray(bt, copy=True)

# split in linalg dimension or none
for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
a.resplit_(s0)
b.resplit_(s1)

res = ht.linalg.solve_triangular(a, b)

self.assertTrue(ht.allclose(c, res))

# split in batch dimension
s = len(batch_shape) - 1
a.resplit_(s)
b.resplit_(s)
c.resplit_(s)

res = ht.linalg.solve_triangular(a, b)

self.assertTrue(ht.allclose(c, res))
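
As a reading aid for the distributed branches above, the block backward substitution they implement can be sketched with plain PyTorch on a single process. The two-block partition below is purely illustrative (the sizes are made up); in the distributed code each block lives on a different rank, and the update of the remaining right-hand side is communicated via Scatterv or Bcast instead of being computed locally:

import torch

n, k, split = 6, 2, 3  # hypothetical sizes: a 6 x 6 system cut into two 3-row blocks
A = torch.triu(torch.rand(n, n) + n * torch.eye(n))  # well-conditioned upper triangular matrix
b = torch.rand(n, k)

# Block structure: A = [[A00, A01], [0, A11]], x = [x0; x1], b = [b0; b1].
A00, A01, A11 = A[:split, :split], A[:split, split:], A[split:, split:]
b0, b1 = b[:split], b[split:]

# 1) Solve the trailing block (done by the last rank in solve_triangular).
x1 = torch.linalg.solve_triangular(A11, b1, upper=True)
# 2) Update the remaining right-hand side (the Scatterv/Bcast step in the distributed code).
b0_tilde = b0 - A01 @ x1
# 3) Solve the leading block (done by rank 0).
x0 = torch.linalg.solve_triangular(A00, b0_tilde, upper=True)

x = torch.cat([x0, x1])
print(torch.allclose(A @ x, b, atol=1e-5))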