Skip to content

Commit

Permalink
broadcast_matmat -> np.vectorize
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinpfoertner committed Nov 9, 2022
1 parent f1dd252 commit f206691
Showing 1 changed file with 11 additions and 35 deletions.
46 changes: 11 additions & 35 deletions src/probnum/linops/_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,17 +1083,10 @@ def broadcast_matmat(
"""Broadcasting for a (implicitly defined) matrix-matrix product.
Convenience function / decorator to broadcast the definition of a matrix-matrix
product to vectors. This can be used to easily construct a new linear operator
only from a matrix-matrix product.
product to stacks of matrices. This can be used to easily construct a new linear
operator only from a matrix-matrix product.
"""

def _matmul(x: np.ndarray) -> np.ndarray:
if x.ndim == 2:
return matmat(x)

return _apply_to_matrix_stack(matmat, x)

return _matmul
return np.vectorize(matmat, signature="(n,k)->(m,k)")

@property
def _inexact_dtype(self) -> np.dtype:
Expand All @@ -1103,27 +1096,6 @@ def _inexact_dtype(self) -> np.dtype:
return np.double


def _apply_to_matrix_stack(
mat_fn: Callable[[np.ndarray], np.ndarray], x: np.ndarray
) -> np.ndarray:
idcs = np.ndindex(x.shape[:-2])

# Shape and dtype inference
idx0 = next(idcs)
y0 = mat_fn(x[idx0])

# Result buffer
y = np.empty(x.shape[:-2] + y0.shape, dtype=y0.dtype)

# Fill buffer
y[idx0] = y0

for idx in idcs:
y[idx] = mat_fn(x[idx])

return y


def _call_if_implemented(method: Optional[callable]) -> callable:
if method is not None:
return method
Expand Down Expand Up @@ -1347,15 +1319,19 @@ def __init__(self, linop: LinearOperator):

self._linop = linop

solve = np.vectorize(
self._solve,
excluded=("trans",),
signature="(n, k)->(n, k)",
)

super().__init__(
shape=self._linop.shape,
dtype=self._linop._inexact_dtype,
matmul=LinearOperator.broadcast_matmat(self._solve),
matmul=solve,
transpose=lambda: TransposedLinearOperator(
self,
matmul=LinearOperator.broadcast_matmat(
lambda x: self._solve(x, trans=True)
),
matmul=lambda x: solve(x, trans=True),
),
inverse=lambda: self._linop,
det=lambda: 1 / self._linop.det(),
Expand Down

0 comments on commit f206691

Please sign in to comment.