Skip to content

Commit

Permalink
Use triangular solves in default _solve
Browse files Browse the repository at this point in the history
Makes the default `_solve` implementation of `LinearOperator` use
triangular solves if the linear operator is lower/upper triangular
Add test cases for triangular matrices
  • Loading branch information
timweiland committed Mar 8, 2023
1 parent 4d435e4 commit aa0b889
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/probnum/linops/_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,7 @@ def _apply(self, x: np.ndarray, axis: int) -> np.ndarray:
return np.moveaxis(self @ np.moveaxis(x, axis, -2), -2, axis)

def __call__(self, x: np.ndarray, axis: Optional[int] = None) -> np.ndarray:
"""Apply the linear operator to an input array along a specified
axis.
"""Apply the linear operator to an input array along a specified axis.
Parameters
----------
Expand Down Expand Up @@ -294,6 +293,10 @@ def _solve(self, B: np.ndarray) -> np.ndarray:
"""
assert B.ndim == 2

if self.is_lower_triangular or self.is_upper_triangular:
return scipy.linalg.solve_triangular(
self.todense(), B, lower=self.is_lower_triangular, trans="N"
)
if self.is_symmetric:
if self.is_positive_definite is not False:
try:
Expand Down
28 changes: 28 additions & 0 deletions tests/test_linops/test_linops_cases/linear_operator_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
np.array([[1.0, -2.0], [-2.0, 5.0]]),
random_spd_matrix(np.random.default_rng(597), dim=10),
]
lower_triangular_matrices = [
np.array([[5.0]]),
np.array([[1.0, 0.0], [2.0, 3.0]]),
np.array([[4.0, 0.0, 0.0], [48.0, 60.0, 0.0], [21.0, 39.0, 7.0]]),
]
upper_triangular_matrices = [mat.T for mat in lower_triangular_matrices]


@pytest.mark.parametrize("matrix", matrices)
Expand Down Expand Up @@ -63,6 +69,28 @@ def case_matrix_spd(matrix: np.ndarray) -> Tuple[pn.linops.LinearOperator, np.nd
return linop, matrix


@pytest_cases.case(tags=("square", "lower-triangular"))
@pytest_cases.parametrize("matrix", lower_triangular_matrices)
def case_matrix_lower_triangular(
matrix: np.ndarray,
) -> Tuple[pn.linops.LinearOperator, np.ndarray]:
linop = pn.linops.Matrix(matrix)
linop.is_lower_triangular = True

return linop, matrix


@pytest_cases.case(tags=("square", "upper-triangular"))
@pytest_cases.parametrize("matrix", upper_triangular_matrices)
def case_matrix_upper_triangular(
matrix: np.ndarray,
) -> Tuple[pn.linops.LinearOperator, np.ndarray]:
linop = pn.linops.Matrix(matrix)
linop.is_upper_triangular = True

return linop, matrix


@pytest_cases.case(tags=("square", "symmetric"))
def case_matrix_symmetric_indefinite() -> Tuple[pn.linops.LinearOperator, np.ndarray]:
matrix = np.diag((2.1, 1.3, -0.5))
Expand Down

0 comments on commit aa0b889

Please sign in to comment.