From 8ffa51c1d72aa4bdf503d4007d23850f9b52166b Mon Sep 17 00:00:00 2001 From: Hoppe Date: Tue, 10 Oct 2023 10:55:07 +0200 Subject: [PATCH 01/21] added my first ideas for an implementation of solve_triangular (not all cases covered, may be inefficient, no tests so far!) --- heat/core/linalg/solver.py | 93 +++++++++++++++++++++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) diff --git a/heat/core/linalg/solver.py b/heat/core/linalg/solver.py index 3273fc739c..383103321c 100644 --- a/heat/core/linalg/solver.py +++ b/heat/core/linalg/solver.py @@ -5,10 +5,11 @@ from ..dndarray import DNDarray from ..sanitation import sanitize_out from typing import List, Dict, Any, TypeVar, Union, Tuple, Optional +from .. import factories import torch -__all__ = ["cg", "lanczos"] +__all__ = ["cg", "lanczos", "solve_triangular"] def cg(A: DNDarray, b: DNDarray, x0: DNDarray, out: Optional[DNDarray] = None) -> DNDarray: @@ -268,3 +269,93 @@ def lanczos( V.resplit_(axis=None) return V, T + + +def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: + """ + My triangular solver, based on blockwise trisolves... + """ + if not isinstance(A, DNDarray) or not isinstance(b, DNDarray): + raise RuntimeError("Arguments need to be a DNDarrays.") + if not A.ndim == 2: + raise RuntimeError("A needs to be a 2D matrix") + if not b.ndim <= 2: + raise RuntimeError("b needs to be a vector (1D) or a matrix (2D)") + if not A.shape[0] == A.shape[1]: + raise RuntimeError("A needs to be a square matrix.") + if not (b.split == 0 or b.split is None): + raise RuntimeError("split=1 is not allowed for the right hand side.") + if not b.shape[0] == A.shape[0]: + raise RuntimeError("Dimension mismath of A and b.") + if ( + A.split is not None + and b.split is not None + and not all(A.lshape_map[:, A.split] == b.lshape_map[:, 0]) + ): + raise RuntimeError("Local arrays of A and b have different sizes.") + + if A.split is None: + x = torch.linalg.solve_triangular(A.larray, b.larray, upper=True) + return factories.array(x, dtype=b.dtype, device=b.device, comm=b.comm) + + nprocs = A.comm.Get_size() + A_lshapes_cum = torch.hstack( + [torch.zeros(1, dtype=torch.int32), torch.cumsum(A.lshape_map[:, A.split], 0)] + ) + btilde_loc = b.larray.clone() + x = factories.zeros_like(b, comm=b.comm) + + if A.split == 1: + for i in range(nprocs - 1, -1, -1): + res = torch.zeros( + (A_lshapes_cum[i], b.shape[1]), + dtype=b.dtype.torch_type(), + device=b.device.torch_device, + ) + if A.comm.rank == i: + x.larray = torch.linalg.solve_triangular( + A.larray[A_lshapes_cum[i] : A_lshapes_cum[i + 1], :], btilde_loc, upper=True + ) + res = A.larray[: A_lshapes_cum[i], :] @ x.larray + if i > 0: + req = A.comm.Ibcast(res, root=i) + req.Wait() + if A.comm.rank < i: + j = A.comm.rank + btilde_loc -= res[A_lshapes_cum[j] : A_lshapes_cum[j + 1], :] + + # if A.split == 1: + # for i in range(nprocs-1,-1,-1): + # count = b.lshape[0]*b.lshape[1] + # displ = (A_lshapes_cum*b.shape[1]).numpy()[:-1] + # res_send = None + # res_recv = torch.zeros(b.lshape[0]*b.lshape[1], dtype=b.dtype.torch_type(), device=b.device.torch_device) + # if A.comm.rank == i: + # x.larray = torch.linalg.solve_triangular(A.larray[A_lshapes_cum[i]:A_lshapes_cum[i+1],:],btilde_loc,upper=True) + # res_send = (A.larray @ x.larray).flatten() + # if i > 0: + # A.comm.handle.Scatterv([res_send, count, displ, MPI.DOUBLE], res_recv, root=i) + # if A.comm.rank < i: + # j = A.comm.rank + # btilde_loc -= res_recv.reshape(b.lshape) + + else: + for i in range(nprocs - 1, -1, -1): + x_from_i = 
torch.zeros( + (x.lshape_map[i, 0], x.lshape_map[i, 1]), + dtype=b.dtype.torch_type(), + device=b.device.torch_device, + ) + if A.comm.rank == i: + x.larray = torch.linalg.solve_triangular( + A.larray[:, A_lshapes_cum[i] : A_lshapes_cum[i + 1]], btilde_loc, upper=True + ) + x_from_i = x.larray + if i > 0: + req = A.comm.Ibcast(x_from_i, root=i) + req.Wait() + if A.comm.rank < i: + j = A.comm.rank + btilde_loc -= A.larray[:, A_lshapes_cum[i] : A_lshapes_cum[i + 1]] @ x_from_i + + return x From 8003642fe101e784dc082f0d6ce2073a702f4734 Mon Sep 17 00:00:00 2001 From: Osterfeld Date: Tue, 10 Oct 2023 12:55:53 +0200 Subject: [PATCH 02/21] readme test --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f19e8dd99c..cb96061611 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ Heat is a distributed tensor framework for high performance data analytics. -# Project Status +## Project Status [![CPU/CUDA/ROCm tests](https://codebase.helmholtz.cloud/helmholtz-analytics/ci/badges/heat/base/pipeline.svg)](https://codebase.helmholtz.cloud/helmholtz-analytics/ci/-/commits/heat/base) [![Documentation Status](https://readthedocs.org/projects/heat/badge/?version=latest)](https://heat.readthedocs.io/en/latest/?badge=latest) From e6c88e8a4d0972852417d9995ef8d3facf5611fe Mon Sep 17 00:00:00 2001 From: Fabian Hoppe <112093564+mrfh92@users.noreply.github.com> Date: Thu, 12 Oct 2023 09:45:02 +0200 Subject: [PATCH 03/21] dummy commit to find out whether CI works... --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index cb96061611..f19e8dd99c 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ Heat is a distributed tensor framework for high performance data analytics. 
-## Project Status +# Project Status [![CPU/CUDA/ROCm tests](https://codebase.helmholtz.cloud/helmholtz-analytics/ci/badges/heat/base/pipeline.svg)](https://codebase.helmholtz.cloud/helmholtz-analytics/ci/-/commits/heat/base) [![Documentation Status](https://readthedocs.org/projects/heat/badge/?version=latest)](https://heat.readthedocs.io/en/latest/?badge=latest) From 3749dbff2898b0a4ca21f73e39a6bf07b01512c7 Mon Sep 17 00:00:00 2001 From: Osterfeld Date: Tue, 16 Jan 2024 18:24:35 +0100 Subject: [PATCH 04/21] small improvements in solve_triangular for the case A.split == 1 --- heat/core/linalg/solver.py | 42 +++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/heat/core/linalg/solver.py b/heat/core/linalg/solver.py index 383103321c..d64f932a62 100644 --- a/heat/core/linalg/solver.py +++ b/heat/core/linalg/solver.py @@ -285,7 +285,7 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: raise RuntimeError("A needs to be a square matrix.") if not (b.split == 0 or b.split is None): raise RuntimeError("split=1 is not allowed for the right hand side.") - if not b.shape[0] == A.shape[0]: + if not b.shape[0] == A.shape[1]: raise RuntimeError("Dimension mismath of A and b.") if ( A.split is not None @@ -294,36 +294,46 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: ): raise RuntimeError("Local arrays of A and b have different sizes.") + comm = A.comm + dev = A.device + if A.split is None: x = torch.linalg.solve_triangular(A.larray, b.larray, upper=True) - return factories.array(x, dtype=b.dtype, device=b.device, comm=b.comm) + return factories.array(x, dtype=b.dtype, device=dev, comm=comm) - nprocs = A.comm.Get_size() + nprocs = comm.Get_size() A_lshapes_cum = torch.hstack( [torch.zeros(1, dtype=torch.int32), torch.cumsum(A.lshape_map[:, A.split], 0)] ) btilde_loc = b.larray.clone() - x = factories.zeros_like(b, comm=b.comm) + x = factories.zeros_like(b, comm=comm) if A.split == 1: - for i in range(nprocs - 1, -1, -1): - res = torch.zeros( - (A_lshapes_cum[i], b.shape[1]), - dtype=b.dtype.torch_type(), - device=b.device.torch_device, - ) - if A.comm.rank == i: + for i in range(nprocs - 1, 0, -1): + if comm.rank == i: x.larray = torch.linalg.solve_triangular( A.larray[A_lshapes_cum[i] : A_lshapes_cum[i + 1], :], btilde_loc, upper=True ) res = A.larray[: A_lshapes_cum[i], :] @ x.larray - if i > 0: - req = A.comm.Ibcast(res, root=i) - req.Wait() - if A.comm.rank < i: - j = A.comm.rank + else: + res = torch.zeros( + (A_lshapes_cum[i], b.shape[1]), + dtype=b.dtype.torch_type(), + device=b.device.torch_device, + ) + + req = comm.Ibcast(res, root=i) # why not Bcast? 
+ req.Wait() + + if comm.rank < i: + j = comm.rank btilde_loc -= res[A_lshapes_cum[j] : A_lshapes_cum[j + 1], :] + if comm.rank == 0: + x.larray = torch.linalg.solve_triangular( + A.larray[: A_lshapes_cum[1], :], btilde_loc, upper=True + ) + # if A.split == 1: # for i in range(nprocs-1,-1,-1): # count = b.lshape[0]*b.lshape[1] From de2fe7390ddb55002d024a213a983c4068c3c06c Mon Sep 17 00:00:00 2001 From: Osterfeld Date: Wed, 17 Jan 2024 14:38:02 +0100 Subject: [PATCH 05/21] small improvements in solve_triangular for the case A.split == 0 --- heat/core/linalg/solver.py | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/heat/core/linalg/solver.py b/heat/core/linalg/solver.py index d64f932a62..f97a00671d 100644 --- a/heat/core/linalg/solver.py +++ b/heat/core/linalg/solver.py @@ -296,6 +296,7 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: comm = A.comm dev = A.device + tdev = dev.torch_device if A.split is None: x = torch.linalg.solve_triangular(A.larray, b.larray, upper=True) @@ -319,11 +320,10 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: res = torch.zeros( (A_lshapes_cum[i], b.shape[1]), dtype=b.dtype.torch_type(), - device=b.device.torch_device, + device=tdev, ) - req = comm.Ibcast(res, root=i) # why not Bcast? - req.Wait() + comm.Bcast(res, root=i) if comm.rank < i: j = comm.rank @@ -350,22 +350,28 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: # btilde_loc -= res_recv.reshape(b.lshape) else: - for i in range(nprocs - 1, -1, -1): - x_from_i = torch.zeros( - (x.lshape_map[i, 0], x.lshape_map[i, 1]), - dtype=b.dtype.torch_type(), - device=b.device.torch_device, - ) - if A.comm.rank == i: + for i in range(nprocs - 1, 0, -1): + idims = tuple(x.lshape_map[i]) # broadcasting from node i would be faster + if comm.rank == i: x.larray = torch.linalg.solve_triangular( A.larray[:, A_lshapes_cum[i] : A_lshapes_cum[i + 1]], btilde_loc, upper=True ) x_from_i = x.larray - if i > 0: - req = A.comm.Ibcast(x_from_i, root=i) - req.Wait() - if A.comm.rank < i: - j = A.comm.rank + else: + x_from_i = torch.zeros( + idims, + dtype=b.dtype.torch_type(), + device=tdev, + ) + + comm.Bcast(x_from_i, root=i) + + if comm.rank < i: btilde_loc -= A.larray[:, A_lshapes_cum[i] : A_lshapes_cum[i + 1]] @ x_from_i + if comm.rank == 0: + x.larray = torch.linalg.solve_triangular( + A.larray[:, : A_lshapes_cum[1]], btilde_loc, upper=True + ) + return x From 85fd3af44c96c68838932af1b934daf38481026b Mon Sep 17 00:00:00 2001 From: Osterfeld Date: Wed, 17 Jan 2024 16:43:45 +0100 Subject: [PATCH 06/21] added tests for solve_triangular --- heat/core/linalg/tests/test_solver.py | 38 +++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/heat/core/linalg/tests/test_solver.py b/heat/core/linalg/tests/test_solver.py index f8f9889a9d..47149a4e07 100644 --- a/heat/core/linalg/tests/test_solver.py +++ b/heat/core/linalg/tests/test_solver.py @@ -135,3 +135,41 @@ def test_lanczos(self): with self.assertRaises(NotImplementedError): A = ht.random.randn(10, 10, split=1) V, T = ht.lanczos(A, m=3) + + def test_solve_triangular(self): + k = 100 # data dimension size + torch.manual_seed(42) + + # random triangular matrix inversion + at = torch.rand((k, k)) + # at += torch.eye(k) + at += 1e2 * torch.ones_like(at) # make gaussian elimination more stable + at = torch.triu(at) + + ct = torch.linalg.solve_triangular(at, torch.eye(k), upper=True) + + a = ht.factories.asarray(at, copy=True) + c = ht.factories.asarray(ct, copy=True) + b 
= ht.eye(k) + + for s0, s1 in (None, None), (0, 0), (1, 0): + a.resplit_(s0) + b.resplit_(s1) + + res = ht.linalg.solve_triangular(a, b) + self.assertTrue(ht.allclose(res, c)) + + # triangular ones inversion + # for this test case, the results should be exact + at = torch.triu(torch.ones_like(at)) + ct = torch.linalg.solve_triangular(at, torch.eye(k), upper=True) + + a = ht.factories.asarray(at, copy=True) + c = ht.factories.asarray(ct, copy=True) + + for s0, s1 in (None, None), (0, 0), (1, 0): + a.resplit_(s0) + b.resplit_(s1) + + res = ht.linalg.solve_triangular(a, b) + self.assertTrue(ht.equal(res, c)) From aa27c2c89b2c9aa5d27d3bc58ae03a0a04f7786d Mon Sep 17 00:00:00 2001 From: Osterfeld Date: Tue, 6 Feb 2024 13:53:02 +0100 Subject: [PATCH 07/21] implemented the case A.split == 1 with scatterv instead of bcast to reduce the communication needed --- heat/core/linalg/solver.py | 39 +++++++++++++------------------------- 1 file changed, 13 insertions(+), 26 deletions(-) diff --git a/heat/core/linalg/solver.py b/heat/core/linalg/solver.py index f97a00671d..cca955072a 100644 --- a/heat/core/linalg/solver.py +++ b/heat/core/linalg/solver.py @@ -8,6 +8,7 @@ from .. import factories import torch +from mpi4py import MPI __all__ = ["cg", "lanczos", "solve_triangular"] @@ -311,47 +312,33 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: if A.split == 1: for i in range(nprocs - 1, 0, -1): + count = (b.lshape_map[:, 0] * b.lshape_map[:, 1]).numpy() + displ = (A_lshapes_cum * b.shape[1]).numpy()[:-1] + count[i:] = 0 # nothing to send, as there are only zero rows + displ[i:] = 0 + + res_send = None + res_recv = torch.zeros(count[comm.rank], dtype=torch.double, device=tdev) + if comm.rank == i: x.larray = torch.linalg.solve_triangular( A.larray[A_lshapes_cum[i] : A_lshapes_cum[i + 1], :], btilde_loc, upper=True ) - res = A.larray[: A_lshapes_cum[i], :] @ x.larray - else: - res = torch.zeros( - (A_lshapes_cum[i], b.shape[1]), - dtype=b.dtype.torch_type(), - device=tdev, - ) + res_send = (A.larray @ x.larray).flatten().to(torch.double) - comm.Bcast(res, root=i) + comm.handle.Scatterv([res_send, count, displ, MPI.DOUBLE], res_recv, root=i) if comm.rank < i: - j = comm.rank - btilde_loc -= res[A_lshapes_cum[j] : A_lshapes_cum[j + 1], :] + btilde_loc -= res_recv.reshape(b.lshape) if comm.rank == 0: x.larray = torch.linalg.solve_triangular( A.larray[: A_lshapes_cum[1], :], btilde_loc, upper=True ) - # if A.split == 1: - # for i in range(nprocs-1,-1,-1): - # count = b.lshape[0]*b.lshape[1] - # displ = (A_lshapes_cum*b.shape[1]).numpy()[:-1] - # res_send = None - # res_recv = torch.zeros(b.lshape[0]*b.lshape[1], dtype=b.dtype.torch_type(), device=b.device.torch_device) - # if A.comm.rank == i: - # x.larray = torch.linalg.solve_triangular(A.larray[A_lshapes_cum[i]:A_lshapes_cum[i+1],:],btilde_loc,upper=True) - # res_send = (A.larray @ x.larray).flatten() - # if i > 0: - # A.comm.handle.Scatterv([res_send, count, displ, MPI.DOUBLE], res_recv, root=i) - # if A.comm.rank < i: - # j = A.comm.rank - # btilde_loc -= res_recv.reshape(b.lshape) - else: for i in range(nprocs - 1, 0, -1): - idims = tuple(x.lshape_map[i]) # broadcasting from node i would be faster + idims = tuple(x.lshape_map[i]) if comm.rank == i: x.larray = torch.linalg.solve_triangular( A.larray[:, A_lshapes_cum[i] : A_lshapes_cum[i + 1]], btilde_loc, upper=True From 7f2ccc55dd839d53de2c7554818947478f1606cf Mon Sep 17 00:00:00 2001 From: Osterfeld Date: Tue, 20 Feb 2024 16:48:59 +0100 Subject: [PATCH 08/21] implemented solve_triangular for 
batched inputs and b.split = None --- heat/core/linalg/solver.py | 175 +++++++++++++++++++++++-------------- 1 file changed, 110 insertions(+), 65 deletions(-) diff --git a/heat/core/linalg/solver.py b/heat/core/linalg/solver.py index cca955072a..94bc4f9a63 100644 --- a/heat/core/linalg/solver.py +++ b/heat/core/linalg/solver.py @@ -8,7 +8,6 @@ from .. import factories import torch -from mpi4py import MPI __all__ = ["cg", "lanczos", "solve_triangular"] @@ -274,91 +273,137 @@ def lanczos( def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: """ - My triangular solver, based on blockwise trisolves... + Solve upper triangular systems of linear equations. + + Input: + A - an upper triangular (possibly batched) invertible square (n x n) matrix + b - (possibly batched) n x k matrix + + Output: + The unique solution x of A * x = b. + + Vectors b have to be given as n x 1 matrices. """ if not isinstance(A, DNDarray) or not isinstance(b, DNDarray): raise RuntimeError("Arguments need to be a DNDarrays.") - if not A.ndim == 2: - raise RuntimeError("A needs to be a 2D matrix") - if not b.ndim <= 2: - raise RuntimeError("b needs to be a vector (1D) or a matrix (2D)") - if not A.shape[0] == A.shape[1]: + if not A.ndim >= 2: + raise RuntimeError("A needs to be a 2D matrix.") + if not b.ndim == A.ndim: + raise RuntimeError("b needs to be a 2D matrix.") + if not A.shape[-2] == A.shape[-1]: raise RuntimeError("A needs to be a square matrix.") - if not (b.split == 0 or b.split is None): + + batch_dim = A.ndim - 2 + batch_shape = A.shape[:batch_dim] + + if not A.shape[:batch_dim] == b.shape[:batch_dim]: + raise RuntimeError("Batch dimensions of A and b must be of the same shape.") + if b.split == batch_dim + 1: raise RuntimeError("split=1 is not allowed for the right hand side.") - if not b.shape[0] == A.shape[1]: + if not b.shape[batch_dim] == A.shape[-1]: raise RuntimeError("Dimension mismath of A and b.") + if ( - A.split is not None - and b.split is not None - and not all(A.lshape_map[:, A.split] == b.lshape_map[:, 0]) - ): - raise RuntimeError("Local arrays of A and b have different sizes.") + A.split is not None and A.split < batch_dim or b.split is not None and b.split < batch_dim + ): # batch split + if A.split != b.split: + raise RuntimeError( + "If a split dimension is a batch dimension, A and b must have the same split dimension." 
+ ) + else: + if ( + A.split is not None and b.split is not None + ): # both la dimensions split --> b.split = batch_dim + if not all(A.lshape_map[:, A.split] == b.lshape_map[:, batch_dim]): + raise RuntimeError("Local arrays of A and b have different sizes.") comm = A.comm dev = A.device tdev = dev.torch_device - if A.split is None: + if A.split is None: # A not split x = torch.linalg.solve_triangular(A.larray, b.larray, upper=True) return factories.array(x, dtype=b.dtype, device=dev, comm=comm) - nprocs = comm.Get_size() - A_lshapes_cum = torch.hstack( - [torch.zeros(1, dtype=torch.int32), torch.cumsum(A.lshape_map[:, A.split], 0)] - ) - btilde_loc = b.larray.clone() - x = factories.zeros_like(b, comm=comm) - - if A.split == 1: - for i in range(nprocs - 1, 0, -1): - count = (b.lshape_map[:, 0] * b.lshape_map[:, 1]).numpy() - displ = (A_lshapes_cum * b.shape[1]).numpy()[:-1] - count[i:] = 0 # nothing to send, as there are only zero rows - displ[i:] = 0 - - res_send = None - res_recv = torch.zeros(count[comm.rank], dtype=torch.double, device=tdev) - - if comm.rank == i: - x.larray = torch.linalg.solve_triangular( - A.larray[A_lshapes_cum[i] : A_lshapes_cum[i + 1], :], btilde_loc, upper=True - ) - res_send = (A.larray @ x.larray).flatten().to(torch.double) + if A.split < batch_dim: # batch split + x = factories.zeros_like(b, comm=comm, split=A.split) + x.larray = torch.linalg.solve_triangular(A.larray, b.larray, upper=True) - comm.handle.Scatterv([res_send, count, displ, MPI.DOUBLE], res_recv, root=i) + return x - if comm.rank < i: - btilde_loc -= res_recv.reshape(b.lshape) + nprocs = comm.Get_size() - if comm.rank == 0: - x.larray = torch.linalg.solve_triangular( - A.larray[: A_lshapes_cum[1], :], btilde_loc, upper=True - ) + if A.split >= batch_dim: # both splits in la dims + A_lshapes_cum = torch.hstack( + [torch.zeros(1, dtype=torch.int32), torch.cumsum(A.lshape_map[:, A.split], 0)] + ) - else: - for i in range(nprocs - 1, 0, -1): - idims = tuple(x.lshape_map[i]) - if comm.rank == i: + if b.split is None: + btilde_loc = b.larray[ + ..., A_lshapes_cum[comm.rank] : A_lshapes_cum[comm.rank + 1], : + ].clone() + else: # b is split at la dim 0 + btilde_loc = b.larray.clone() + + x = factories.zeros_like( + b, comm=comm, split=batch_dim + ) # split at la dim 0 in case b is not split + + if A.split == batch_dim + 1: + for i in range(nprocs - 1, 0, -1): + count = x.lshape_map[:, batch_dim].clone().numpy() + displ = A_lshapes_cum[:-1].clone().numpy() + count[i:] = 0 # nothing to send, as there are only zero rows + displ[i:] = 0 + + res_send = torch.empty(0) + res_recv = torch.zeros((*batch_shape, count[comm.rank], b.shape[-1]), device=tdev) + + if comm.rank == i: + x.larray = torch.linalg.solve_triangular( + A.larray[..., A_lshapes_cum[i] : A_lshapes_cum[i + 1], :], + btilde_loc, + upper=True, + ) + res_send = A.larray @ x.larray + + comm.Scatterv((res_send, count, displ), res_recv, root=i, axis=batch_dim) + + if comm.rank < i: + btilde_loc -= res_recv + + if comm.rank == 0: x.larray = torch.linalg.solve_triangular( - A.larray[:, A_lshapes_cum[i] : A_lshapes_cum[i + 1]], btilde_loc, upper=True - ) - x_from_i = x.larray - else: - x_from_i = torch.zeros( - idims, - dtype=b.dtype.torch_type(), - device=tdev, + A.larray[..., : A_lshapes_cum[1], :], btilde_loc, upper=True ) - comm.Bcast(x_from_i, root=i) - - if comm.rank < i: - btilde_loc -= A.larray[:, A_lshapes_cum[i] : A_lshapes_cum[i + 1]] @ x_from_i - - if comm.rank == 0: - x.larray = torch.linalg.solve_triangular( - A.larray[:, : 
A_lshapes_cum[1]], btilde_loc, upper=True - ) + else: # split dim is la dim 0 + for i in range(nprocs - 1, 0, -1): + idims = tuple(x.lshape_map[i]) + if comm.rank == i: + x.larray = torch.linalg.solve_triangular( + A.larray[..., :, A_lshapes_cum[i] : A_lshapes_cum[i + 1]], + btilde_loc, + upper=True, + ) + x_from_i = x.larray + else: + x_from_i = torch.zeros( + idims, + dtype=b.dtype.torch_type(), + device=tdev, + ) + + comm.Bcast(x_from_i, root=i) + + if comm.rank < i: + btilde_loc -= ( + A.larray[..., :, A_lshapes_cum[i] : A_lshapes_cum[i + 1]] @ x_from_i + ) + + if comm.rank == 0: + x.larray = torch.linalg.solve_triangular( + A.larray[..., :, : A_lshapes_cum[1]], btilde_loc, upper=True + ) return x From 8d2af3e815eb6b05fc36c9d7b4fc0f3b7425d895 Mon Sep 17 00:00:00 2001 From: Osterfeld Date: Tue, 27 Feb 2024 16:19:16 +0100 Subject: [PATCH 09/21] implemented solve_triangular for A.split = None, b.split = -2 --- heat/core/linalg/solver.py | 50 ++++++++++++++++++++++++++++++++++---- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/heat/core/linalg/solver.py b/heat/core/linalg/solver.py index f2295bc46c..81b2de1d9c 100644 --- a/heat/core/linalg/solver.py +++ b/heat/core/linalg/solver.py @@ -322,9 +322,51 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: dev = A.device tdev = dev.torch_device + nprocs = comm.Get_size() + if A.split is None: # A not split - x = torch.linalg.solve_triangular(A.larray, b.larray, upper=True) - return factories.array(x, dtype=b.dtype, device=dev, comm=comm) + if b.split is None: + x = torch.linalg.solve_triangular(A.larray, b.larray, upper=True) + + return factories.array(x, dtype=b.dtype, device=dev, comm=comm) + else: # A not split, b.split == -2 + b_lshapes_cum = torch.hstack( + [torch.zeros(1, dtype=torch.int32), torch.cumsum(b.lshape_map[:, -2], 0)] + ) + + btilde_loc = b.larray.clone() + A_loc = A.larray[..., b_lshapes_cum[comm.rank] : b_lshapes_cum[comm.rank + 1]] + + x = factories.zeros_like(b, comm=comm) + + for i in range(nprocs - 1, 0, -1): + count = x.lshape_map[:, batch_dim].clone().numpy() + displ = b_lshapes_cum[:-1].clone().numpy() + count[i:] = 0 # nothing to send, as there are only zero rows + displ[i:] = 0 + + res_send = torch.empty(0) + res_recv = torch.zeros((*batch_shape, count[comm.rank], b.shape[-1]), device=tdev) + + if comm.rank == i: + x.larray = torch.linalg.solve_triangular( + A_loc[..., b_lshapes_cum[i] : b_lshapes_cum[i + 1], :], + btilde_loc, + upper=True, + ) + res_send = A_loc @ x.larray + + comm.Scatterv((res_send, count, displ), res_recv, root=i, axis=batch_dim) + + if comm.rank < i: + btilde_loc -= res_recv + + if comm.rank == 0: + x.larray = torch.linalg.solve_triangular( + A_loc[..., : b_lshapes_cum[1], :], btilde_loc, upper=True + ) + + return x if A.split < batch_dim: # batch split x = factories.zeros_like(b, comm=comm, split=A.split) @@ -332,8 +374,6 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: return x - nprocs = comm.Get_size() - if A.split >= batch_dim: # both splits in la dims A_lshapes_cum = torch.hstack( [torch.zeros(1, dtype=torch.int32), torch.cumsum(A.lshape_map[:, A.split], 0)] @@ -407,4 +447,4 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: A.larray[..., :, : A_lshapes_cum[1]], btilde_loc, upper=True ) - return x + return x From fd7e2d7a17dd87e995fcd6a96e49a130e390201e Mon Sep 17 00:00:00 2001 From: Osterfeld Date: Tue, 27 Feb 2024 16:40:24 +0100 Subject: [PATCH 10/21] added tests for batched input of solve_triangular --- 
heat/core/linalg/tests/test_solver.py | 45 +++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 3 deletions(-) diff --git a/heat/core/linalg/tests/test_solver.py b/heat/core/linalg/tests/test_solver.py index 47149a4e07..2b427a34fd 100644 --- a/heat/core/linalg/tests/test_solver.py +++ b/heat/core/linalg/tests/test_solver.py @@ -137,8 +137,10 @@ def test_lanczos(self): V, T = ht.lanczos(A, m=3) def test_solve_triangular(self): - k = 100 # data dimension size torch.manual_seed(42) + # non-batched tests + + k = 100 # data dimension size # random triangular matrix inversion at = torch.rand((k, k)) @@ -152,7 +154,7 @@ def test_solve_triangular(self): c = ht.factories.asarray(ct, copy=True) b = ht.eye(k) - for s0, s1 in (None, None), (0, 0), (1, 0): + for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2): a.resplit_(s0) b.resplit_(s1) @@ -167,9 +169,46 @@ def test_solve_triangular(self): a = ht.factories.asarray(at, copy=True) c = ht.factories.asarray(ct, copy=True) - for s0, s1 in (None, None), (0, 0), (1, 0): + for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2): a.resplit_(s0) b.resplit_(s1) res = ht.linalg.solve_triangular(a, b) self.assertTrue(ht.equal(res, c)) + + # batched tests + + batch_shape = (10,) # batch dimensions shape + # batch_shape = tuple() # no batch dimensions + m = 100 # data dimension size + + at = torch.rand((*batch_shape, m, m)) + # at += torch.eye(k) + at += 1e2 * torch.ones_like(at) # make gaussian elimination more stable + at = torch.triu(at) + bt = torch.eye(m).expand((*batch_shape, -1, -1)) + + ct = torch.linalg.solve_triangular(at, bt, upper=True) + + a = ht.factories.asarray(at, copy=True) + c = ht.factories.asarray(ct, copy=True) + b = ht.factories.asarray(bt, copy=True) + + # split in linalg dimension or none + for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2): + a.resplit_(s0) + b.resplit_(s1) + + res = ht.linalg.solve_triangular(a, b) + + self.assertTrue(ht.allclose(c, res)) + + # split in batch dimension + s = 0 + a.resplit_(s) + b.resplit_(s) + c.resplit_(s) + + res = ht.linalg.solve_triangular(a, b) + + self.assertTrue(ht.allclose(c, res)) From 45912601cbd6bfc87db1e41c8ff024298c5c647d Mon Sep 17 00:00:00 2001 From: Osterfeld Date: Thu, 29 Feb 2024 17:17:44 +0100 Subject: [PATCH 11/21] fixed a bug when not using the cpu --- heat/core/linalg/solver.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/heat/core/linalg/solver.py b/heat/core/linalg/solver.py index 81b2de1d9c..d43a07db0c 100644 --- a/heat/core/linalg/solver.py +++ b/heat/core/linalg/solver.py @@ -340,8 +340,8 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: x = factories.zeros_like(b, comm=comm) for i in range(nprocs - 1, 0, -1): - count = x.lshape_map[:, batch_dim].clone().numpy() - displ = b_lshapes_cum[:-1].clone().numpy() + count = x.lshape_map[:, batch_dim].to(torch.device("cpu")).clone().numpy() + displ = b_lshapes_cum[:-1].to(torch.device("cpu")).clone().numpy() count[i:] = 0 # nothing to send, as there are only zero rows displ[i:] = 0 @@ -369,7 +369,7 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: return x if A.split < batch_dim: # batch split - x = factories.zeros_like(b, comm=comm, split=A.split) + x = factories.zeros_like(b, device=dev, comm=comm, split=A.split) x.larray = torch.linalg.solve_triangular(A.larray, b.larray, upper=True) return x @@ -392,8 +392,8 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: if 
A.split == batch_dim + 1: for i in range(nprocs - 1, 0, -1): - count = x.lshape_map[:, batch_dim].clone().numpy() - displ = A_lshapes_cum[:-1].clone().numpy() + count = x.lshape_map[:, batch_dim].to(torch.device("cpu")).clone().numpy() + displ = A_lshapes_cum[:-1].to(torch.device("cpu")).clone().numpy() count[i:] = 0 # nothing to send, as there are only zero rows displ[i:] = 0 From cf84b8801f15f0c57b80e882779ee78417d8efbc Mon Sep 17 00:00:00 2001 From: Osterfeld Date: Thu, 29 Feb 2024 17:58:51 +0100 Subject: [PATCH 12/21] specify device --- heat/core/linalg/solver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/heat/core/linalg/solver.py b/heat/core/linalg/solver.py index d43a07db0c..a9ee4b6a12 100644 --- a/heat/core/linalg/solver.py +++ b/heat/core/linalg/solver.py @@ -337,7 +337,7 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: btilde_loc = b.larray.clone() A_loc = A.larray[..., b_lshapes_cum[comm.rank] : b_lshapes_cum[comm.rank + 1]] - x = factories.zeros_like(b, comm=comm) + x = factories.zeros_like(b, device=dev, comm=comm) for i in range(nprocs - 1, 0, -1): count = x.lshape_map[:, batch_dim].to(torch.device("cpu")).clone().numpy() @@ -387,7 +387,7 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: btilde_loc = b.larray.clone() x = factories.zeros_like( - b, comm=comm, split=batch_dim + b, device=dev, comm=comm, split=batch_dim ) # split at la dim 0 in case b is not split if A.split == batch_dim + 1: From 18d3f00fd2ff94bf803206f54402db220ace6ff6 Mon Sep 17 00:00:00 2001 From: Osterfeld Date: Tue, 5 Mar 2024 16:38:56 +0100 Subject: [PATCH 13/21] fixed on gpu --- heat/core/linalg/solver.py | 10 ++++++++-- heat/core/linalg/tests/test_solver.py | 16 ++++++++-------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/heat/core/linalg/solver.py b/heat/core/linalg/solver.py index a9ee4b6a12..64946ec9e0 100644 --- a/heat/core/linalg/solver.py +++ b/heat/core/linalg/solver.py @@ -331,7 +331,10 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: return factories.array(x, dtype=b.dtype, device=dev, comm=comm) else: # A not split, b.split == -2 b_lshapes_cum = torch.hstack( - [torch.zeros(1, dtype=torch.int32), torch.cumsum(b.lshape_map[:, -2], 0)] + [ + torch.zeros(1, dtype=torch.int32, device=tdev), + torch.cumsum(b.lshape_map[:, -2], 0), + ] ) btilde_loc = b.larray.clone() @@ -376,7 +379,10 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: if A.split >= batch_dim: # both splits in la dims A_lshapes_cum = torch.hstack( - [torch.zeros(1, dtype=torch.int32), torch.cumsum(A.lshape_map[:, A.split], 0)] + [ + torch.zeros(1, dtype=torch.int32, device=tdev), + torch.cumsum(A.lshape_map[:, A.split], 0), + ] ) if b.split is None: diff --git a/heat/core/linalg/tests/test_solver.py b/heat/core/linalg/tests/test_solver.py index 2b427a34fd..5d06d7a877 100644 --- a/heat/core/linalg/tests/test_solver.py +++ b/heat/core/linalg/tests/test_solver.py @@ -138,17 +138,18 @@ def test_lanczos(self): def test_solve_triangular(self): torch.manual_seed(42) - # non-batched tests + tdev = ht.get_device().torch_device + # non-batched tests k = 100 # data dimension size # random triangular matrix inversion at = torch.rand((k, k)) # at += torch.eye(k) at += 1e2 * torch.ones_like(at) # make gaussian elimination more stable - at = torch.triu(at) + at = torch.triu(at).to(tdev) - ct = torch.linalg.solve_triangular(at, torch.eye(k), upper=True) + ct = torch.linalg.solve_triangular(at, torch.eye(k, device=tdev), upper=True) a = 
ht.factories.asarray(at, copy=True) c = ht.factories.asarray(ct, copy=True) @@ -163,8 +164,8 @@ def test_solve_triangular(self): # triangular ones inversion # for this test case, the results should be exact - at = torch.triu(torch.ones_like(at)) - ct = torch.linalg.solve_triangular(at, torch.eye(k), upper=True) + at = torch.triu(torch.ones_like(at)).to(tdev) + ct = torch.linalg.solve_triangular(at, torch.eye(k, device=tdev), upper=True) a = ht.factories.asarray(at, copy=True) c = ht.factories.asarray(ct, copy=True) @@ -177,7 +178,6 @@ def test_solve_triangular(self): self.assertTrue(ht.equal(res, c)) # batched tests - batch_shape = (10,) # batch dimensions shape # batch_shape = tuple() # no batch dimensions m = 100 # data dimension size @@ -185,8 +185,8 @@ def test_solve_triangular(self): at = torch.rand((*batch_shape, m, m)) # at += torch.eye(k) at += 1e2 * torch.ones_like(at) # make gaussian elimination more stable - at = torch.triu(at) - bt = torch.eye(m).expand((*batch_shape, -1, -1)) + at = torch.triu(at).to(tdev) + bt = torch.eye(m).expand((*batch_shape, -1, -1)).to(tdev) ct = torch.linalg.solve_triangular(at, bt, upper=True) From bac8fc765de0a763f65c0aaa7572d494dc98e861 Mon Sep 17 00:00:00 2001 From: Osterfeld Date: Tue, 26 Mar 2024 10:37:37 +0100 Subject: [PATCH 14/21] improved exception handling for invalid input data --- heat/core/linalg/solver.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/heat/core/linalg/solver.py b/heat/core/linalg/solver.py index 64946ec9e0..3630bc7edb 100644 --- a/heat/core/linalg/solver.py +++ b/heat/core/linalg/solver.py @@ -286,35 +286,36 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: Vectors b have to be given as n x 1 matrices. """ if not isinstance(A, DNDarray) or not isinstance(b, DNDarray): - raise RuntimeError("Arguments need to be a DNDarrays.") + raise TypeError(f"Arguments need to be of type DNDarray, got {type(A)}, {type(b)}.") if not A.ndim >= 2: - raise RuntimeError("A needs to be a 2D matrix.") + raise ValueError("A needs to be a (batched) matrix.") if not b.ndim == A.ndim: - raise RuntimeError("b needs to be a 2D matrix.") + raise ValueError("b needs to have the same number of (batch) dimensions as A.") if not A.shape[-2] == A.shape[-1]: - raise RuntimeError("A needs to be a square matrix.") + raise ValueError("A needs to be a (batched) square matrix.") batch_dim = A.ndim - 2 batch_shape = A.shape[:batch_dim] if not A.shape[:batch_dim] == b.shape[:batch_dim]: - raise RuntimeError("Batch dimensions of A and b must be of the same shape.") + raise ValueError("Batch dimensions of A and b must be of the same shape.") if b.split == batch_dim + 1: - raise RuntimeError("split=1 is not allowed for the right hand side.") + raise ValueError("split=1 is not allowed for the right hand side.") if not b.shape[batch_dim] == A.shape[-1]: - raise RuntimeError("Dimension mismath of A and b.") + raise ValueError("Dimension mismatch of A and b.") if ( A.split is not None and A.split < batch_dim or b.split is not None and b.split < batch_dim ): # batch split if A.split != b.split: - raise RuntimeError( - "If a split dimension is a batch dimension, A and b must have the same split dimension." + raise ValueError( + "If a split dimension is a batch dimension, A and b must have the same split dimension. A possible solution would be a resplit of A or b to the same split dimension." ) else: if ( A.split is not None and b.split is not None ): # both la dimensions split --> b.split = batch_dim + # TODO remove? 
if not all(A.lshape_map[:, A.split] == b.lshape_map[:, batch_dim]):
                 raise RuntimeError("Local arrays of A and b have different sizes.")

From 22507ae77498b616da01b65b85cb51f3c52de77e Mon Sep 17 00:00:00 2001
From: Osterfeld
Date: Tue, 26 Mar 2024 10:54:15 +0100
Subject: [PATCH 15/21] added tests for higher dimensional batches

---
 heat/core/linalg/tests/test_solver.py | 58 ++++++++++++++++-----------
 1 file changed, 34 insertions(+), 24 deletions(-)

diff --git a/heat/core/linalg/tests/test_solver.py b/heat/core/linalg/tests/test_solver.py
index 5d06d7a877..a4c473ac8f 100644
--- a/heat/core/linalg/tests/test_solver.py
+++ b/heat/core/linalg/tests/test_solver.py
@@ -178,37 +178,47 @@ def test_solve_triangular(self):
         self.assertTrue(ht.equal(res, c))

         # batched tests
-        batch_shape = (10,)  # batch dimensions shape
-        # batch_shape = tuple()  # no batch dimensions
+        batch_shapes = [
+            (10,),
+            (
+                4,
+                4,
+                4,
+                20,
+            ),
+        ]
         m = 100  # data dimension size

-        at = torch.rand((*batch_shape, m, m))
-        # at += torch.eye(k)
-        at += 1e2 * torch.ones_like(at)  # make gaussian elimination more stable
-        at = torch.triu(at).to(tdev)
-        bt = torch.eye(m).expand((*batch_shape, -1, -1)).to(tdev)
+        for batch_shape in batch_shapes:
+            # batch_shape = tuple() # no batch dimensions
+
+            at = torch.rand((*batch_shape, m, m))
+            # at += torch.eye(k)
+            at += 1e2 * torch.ones_like(at)  # make gaussian elimination more stable
+            at = torch.triu(at).to(tdev)
+            bt = torch.eye(m).expand((*batch_shape, -1, -1)).to(tdev)

-        ct = torch.linalg.solve_triangular(at, bt, upper=True)
+            ct = torch.linalg.solve_triangular(at, bt, upper=True)

-        a = ht.factories.asarray(at, copy=True)
-        c = ht.factories.asarray(ct, copy=True)
-        b = ht.factories.asarray(bt, copy=True)
+            a = ht.factories.asarray(at, copy=True)
+            c = ht.factories.asarray(ct, copy=True)
+            b = ht.factories.asarray(bt, copy=True)

-        # split in linalg dimension or none
-        for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
-            a.resplit_(s0)
-            b.resplit_(s1)
+            # split in linalg dimension or none
+            for s0, s1 in (None, None), (-2, -2), (-1, -2), (-2, None), (-1, None), (None, -2):
+                a.resplit_(s0)
+                b.resplit_(s1)

-            res = ht.linalg.solve_triangular(a, b)
+                res = ht.linalg.solve_triangular(a, b)

-            self.assertTrue(ht.allclose(c, res))
+                self.assertTrue(ht.allclose(c, res))

-        # split in batch dimension
-        s = 0
-        a.resplit_(s)
-        b.resplit_(s)
-        c.resplit_(s)
+            # split in batch dimension
+            s = len(batch_shape) - 1
+            a.resplit_(s)
+            b.resplit_(s)
+            c.resplit_(s)

-        res = ht.linalg.solve_triangular(a, b)
+            res = ht.linalg.solve_triangular(a, b)

-        self.assertTrue(ht.allclose(c, res))
+            self.assertTrue(ht.allclose(c, res))

From bfac0c339a27f3d72b4bee26a12b9652a74904ff Mon Sep 17 00:00:00 2001
From: Hoppe
Date: Wed, 3 Apr 2024 12:48:05 +0200
Subject: [PATCH 16/21] changed error message for non-equal local sizes along split axis

---
 heat/core/linalg/solver.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/heat/core/linalg/solver.py b/heat/core/linalg/solver.py
index 3630bc7edb..68f3c4cc0e 100644
--- a/heat/core/linalg/solver.py
+++ b/heat/core/linalg/solver.py
@@ -317,7 +317,9 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray:
         ):  # both la dimensions split --> b.split = batch_dim
             # TODO remove?
             if not all(A.lshape_map[:, A.split] == b.lshape_map[:, batch_dim]):
-                raise RuntimeError("Local arrays of A and b have different sizes.")
+                raise RuntimeError(
+                    "The process-local arrays of A and b have different sizes along the split axis. 
This is most likely due to one of the DNDarrays being in an unbalanced state. \n Consider using `A.is_balanced(force_check=True)` and `b.is_balanced(force_check=True)` to check if A and b are balanced; \n then call `A.balance_()` and/or `b.balance_()` in order to achieve equal local shapes along the split axis before applying `solve_triangular`."
+                )

From db370fed67d016914a0b521f903472f1303cc8dd Mon Sep 17 00:00:00 2001
From: Hoppe
Date: Wed, 3 Apr 2024 13:03:19 +0200
Subject: [PATCH 17/21] updated docstring according to our conventions

---
 heat/core/linalg/solver.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/heat/core/linalg/solver.py b/heat/core/linalg/solver.py
index 68f3c4cc0e..0a07428a7b 100644
--- a/heat/core/linalg/solver.py
+++ b/heat/core/linalg/solver.py
@@ -274,16 +274,22 @@ def lanczos(

 def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray:
     """
-    Solve upper triangular systems of linear equations.
+    This function provides a solver for (possibly batched) upper triangular systems of linear equations: it returns x in Ax = b, where A is a (possibly batched) upper triangular matrix and
+    b a (possibly batched) vector or matrix of suitable shape, both provided as input to the function.
+    The implementation builts on the corresponding solver in PyTorch and implements a block-wise version thereof.

-    Input:
-    A - an upper triangular (possibly batched) invertible square (n x n) matrix
-    b - (possibly batched) n x k matrix
-
-    Output:
-    The unique solution x of A * x = b.
+    Parameters
+    ----------
+    A : DNDarray
+        An upper triangular (possibly batched) invertible square (n x n) matrix, i.e. an DNDarray of shape (..., n, n).
+    b : DNDarray
+        a (possibly batched) n x k matrix, i.e. an DNDarray of shape (..., n, k), where the batch-dimensions denoted by ... need to coincide with those of A.
+        Vectors have to be provided as n x 1 matrices and the split dimension of b must the second last dimension if not None.

-    Vectors b have to be given as n x 1 matrices.
+    Note
+    ---------
+    Since such a check might be computationally expensive, we do not check whether A is indeed upper triangular.
+    If you require such a check, please open an issue on our GitHub page and request this feature. 
""" if not isinstance(A, DNDarray) or not isinstance(b, DNDarray): raise TypeError(f"Arguments need to be of type DNDarray, got {type(A)}, {type(b)}.") From af1093768be19fd0def58c12684be04ef3b40562 Mon Sep 17 00:00:00 2001 From: Fabian Hoppe <112093564+mrfh92@users.noreply.github.com> Date: Thu, 11 Apr 2024 11:01:17 +0200 Subject: [PATCH 18/21] Update heat/core/linalg/solver.py Co-authored-by: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> --- heat/core/linalg/solver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/linalg/solver.py b/heat/core/linalg/solver.py index 0a07428a7b..302eeef81c 100644 --- a/heat/core/linalg/solver.py +++ b/heat/core/linalg/solver.py @@ -274,7 +274,7 @@ def lanczos( def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: """ - This function provides a solver for (possibly batched) upper triangular systems of linear equations: it returns x in Ax = b, where A is a (possibly batched) upper triangular matrix and + This function provides a solver for (possibly batched) upper triangular systems of linear equations: it returns `x` in `Ax = b`, where `A` is a (possibly batched) upper triangular matrix and b a (possibly batched) vector or matrix of suitable shape, both provided as input to the function. The implementation builts on the corresponding solver in PyTorch and implements a block-wise version thereof. From 795053f7ad543e16152db2260f33035af15f0dc9 Mon Sep 17 00:00:00 2001 From: Fabian Hoppe <112093564+mrfh92@users.noreply.github.com> Date: Thu, 11 Apr 2024 11:01:35 +0200 Subject: [PATCH 19/21] Update heat/core/linalg/solver.py Co-authored-by: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> --- heat/core/linalg/solver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/linalg/solver.py b/heat/core/linalg/solver.py index 302eeef81c..5fa383c665 100644 --- a/heat/core/linalg/solver.py +++ b/heat/core/linalg/solver.py @@ -275,7 +275,7 @@ def lanczos( def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: """ This function provides a solver for (possibly batched) upper triangular systems of linear equations: it returns `x` in `Ax = b`, where `A` is a (possibly batched) upper triangular matrix and - b a (possibly batched) vector or matrix of suitable shape, both provided as input to the function. + `b` a (possibly batched) vector or matrix of suitable shape, both provided as input to the function. The implementation builts on the corresponding solver in PyTorch and implements a block-wise version thereof. Parameters From ba3578a96a7bfe40f9bc8470414d7592b8328a58 Mon Sep 17 00:00:00 2001 From: Fabian Hoppe <112093564+mrfh92@users.noreply.github.com> Date: Thu, 11 Apr 2024 11:01:52 +0200 Subject: [PATCH 20/21] Update heat/core/linalg/solver.py Co-authored-by: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> --- heat/core/linalg/solver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/linalg/solver.py b/heat/core/linalg/solver.py index 5fa383c665..ddb33078e7 100644 --- a/heat/core/linalg/solver.py +++ b/heat/core/linalg/solver.py @@ -281,7 +281,7 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray: Parameters ---------- A : DNDarray - An upper triangular (possibly batched) invertible square (n x n) matrix, i.e. an DNDarray of shape (..., n, n). + An upper triangular invertible square (n x n) matrix or a batch thereof, i.e. a ``DNDarray`` of shape `(..., n, n)`. b : DNDarray a (possibly batched) n x k matrix, i.e. 
an DNDarray of shape (..., n, k), where the batch-dimensions denoted by ... need to coincide with those of A.
        Vectors have to be provided as n x 1 matrices and the split dimension of b must the second last dimension if not None.

From e248a42e39e9999eadf11098eddea77a822f543a Mon Sep 17 00:00:00 2001
From: Fabian Hoppe <112093564+mrfh92@users.noreply.github.com>
Date: Thu, 11 Apr 2024 11:16:57 +0200
Subject: [PATCH 21/21] Update solver.py

incorporated @ClaudiaComito's suggestions for docstring
---
 heat/core/linalg/solver.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/heat/core/linalg/solver.py b/heat/core/linalg/solver.py
index ddb33078e7..bfa99ee914 100644
--- a/heat/core/linalg/solver.py
+++ b/heat/core/linalg/solver.py
@@ -276,7 +276,7 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray:
     """
     This function provides a solver for (possibly batched) upper triangular systems of linear equations: it returns `x` in `Ax = b`, where `A` is a (possibly batched) upper triangular matrix and
     `b` a (possibly batched) vector or matrix of suitable shape, both provided as input to the function.
-    The implementation builts on the corresponding solver in PyTorch and implements a block-wise version thereof.
+    The implementation builds on the corresponding solver in PyTorch and implements a memory-distributed, MPI-parallel block-wise version thereof.

     Parameters
     ----------
@@ -284,7 +284,7 @@ def solve_triangular(A: DNDarray, b: DNDarray) -> DNDarray:
         An upper triangular invertible square (n x n) matrix or a batch thereof, i.e. a ``DNDarray`` of shape `(..., n, n)`.
     b : DNDarray
         a (possibly batched) n x k matrix, i.e. an DNDarray of shape (..., n, k), where the batch-dimensions denoted by ... need to coincide with those of A.
-        Vectors have to be provided as n x 1 matrices and the split dimension of b must the second last dimension if not None.
+        (Batched) Vectors have to be provided as ... x n x 1 matrices and the split dimension of b must be the second last dimension if not None.

     Note
     ---------
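
For reference, a minimal usage sketch of the routine added by this series. The script is illustrative only (the file name and process count are assumptions, not part of the patches); it builds the same well-conditioned upper triangular test system as `test_solve_triangular` and cross-checks the distributed result against the process-local PyTorch solver:

    # demo_solve_triangular.py -- hypothetical usage sketch, run e.g. with:
    #   mpirun -n 4 python demo_solve_triangular.py
    import torch
    import heat as ht

    n = 100

    # Upper triangular, well-conditioned A, constructed as in the tests:
    # the large constant shift keeps the triangular solve numerically stable.
    at = torch.rand((n, n))
    at += 1e2 * torch.ones_like(at)
    at = torch.triu(at)

    a = ht.factories.asarray(at, copy=True)
    b = ht.eye(n)

    # Distribute A along its columns and b along its rows
    # (A.split == 1, b.split == 0), one of the supported split combinations.
    a.resplit_(1)
    b.resplit_(0)

    # Solve A @ x = b; since b is the identity, x approximates the inverse of A.
    x = ht.linalg.solve_triangular(a, b)

    # Cross-check against the process-local PyTorch result.
    ct = torch.linalg.solve_triangular(at, torch.eye(n), upper=True)
    c = ht.factories.asarray(ct, copy=True)
    print(ht.allclose(x, c))

The split combination above corresponds to the (s0, s1) = (-1, -2) case exercised in the tests; splitting b along its last dimension remains disallowed by the input checks.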