Skip to content

Commit

Permalink
Merge pull request jax-ml#11546 from jakevdp:fix-scipy-sym-pos
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 461976166
  • Loading branch information
jax authors committed Jul 19, 2022
2 parents 9f914a9 + 9090dd1 commit 6ea9e4d
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 17 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* {func}`jax.tree_structure` is deprecated in favor of {func}`jax.tree_util.tree_structure`
* {func}`jax.tree_transpose` is deprecated in favor of {func}`jax.tree_util.tree_transpose`
* {func}`jax.tree_unflatten` is deprecated in favor of {func}`jax.tree_util.tree_unflatten`
* The `sym_pos` argument of {func}`jax.scipy.linalg.solve` is deprecated in favor of `assume_a='pos'`,
following a similar deprecation in {func}`scipy.linalg.solve`.

## jaxlib 0.3.15 (Unreleased)

Expand Down
19 changes: 14 additions & 5 deletions jax/_src/scipy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,9 @@ def qr(a, overwrite_a=False, lwork=None, mode="full", pivoting=False,
return _qr(a, mode, pivoting)


@partial(jit, static_argnames=('sym_pos', 'lower'))
def _solve(a, b, sym_pos, lower):
if not sym_pos:
@partial(jit, static_argnames=('assume_a', 'lower'))
def _solve(a, b, assume_a, lower):
if assume_a != 'pos':
return np_linalg.solve(a, b)

a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
Expand All @@ -232,9 +232,18 @@ def _solve(a, b, sym_pos, lower):
@_wraps(scipy.linalg.solve,
lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'overwrite_b', 'debug', 'check_finite'))
def solve(a, b, sym_pos=False, lower=False, overwrite_a=False, overwrite_b=False,
debug=False, check_finite=True):
debug=False, check_finite=True, assume_a='gen'):
# TODO(jakevdp) remove sym_pos argument after October 2022
del overwrite_a, overwrite_b, debug, check_finite
return _solve(a, b, sym_pos, lower)
valid_assume_a = ['gen', 'sym', 'her', 'pos']
if assume_a not in valid_assume_a:
raise ValueError("Expected assume_a to be one of {valid_assume_a}; got {assume_a!r}")
if sym_pos:
warnings.warn("The sym_pos argument to solve() is deprecated and will be removed "
"in a future JAX release. Use assume_a='pos' instead.",
category=FutureWarning, stacklevel=2)
assume_a = 'pos'
return _solve(a, b, assume_a, lower)

@partial(jit, static_argnames=('trans', 'lower', 'unit_diagonal'))
def _solve_triangular(a, b, trans, lower, unit_diagonal):
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/scipy/sparse/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def _lstsq(a, b):
# faster than jsp.linalg.lstsq
a2 = _dot(a.T.conj(), a)
b2 = _dot(a.T.conj(), b)
return jsp.linalg.solve(a2, b2, sym_pos=True)
return jsp.linalg.solve(a2, b2, assume_a='pos')


def _gmres_batched(A, b, x0, unit_residual, residual_norm, ptol, restart, M):
Expand Down
22 changes: 11 additions & 11 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,31 +1153,31 @@ def args_maker():

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_lhs={}_rhs={}_sym_pos={}_lower={}".format(
"_lhs={}_rhs={}_assume_a={}_lower={}".format(
jtu.format_shape_dtype_string(lhs_shape, dtype),
jtu.format_shape_dtype_string(rhs_shape, dtype),
sym_pos, lower),
assume_a, lower),
"lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
"sym_pos": sym_pos, "lower": lower}
"assume_a": assume_a, "lower": lower}
for lhs_shape, rhs_shape in [
((1, 1), (1, 1)),
((4, 4), (4,)),
((8, 8), (8, 4)),
]
for sym_pos, lower in [
(False, False),
(True, False),
(True, True),
for assume_a, lower in [
('gen', False),
('pos', False),
('pos', True),
]
for dtype in float_types + complex_types))
def testSolve(self, lhs_shape, rhs_shape, dtype, sym_pos, lower):
def testSolve(self, lhs_shape, rhs_shape, dtype, assume_a, lower):
rng = jtu.rand_default(self.rng())
osp_fun = lambda lhs, rhs: osp.linalg.solve(lhs, rhs, sym_pos=sym_pos, lower=lower)
jsp_fun = lambda lhs, rhs: jsp.linalg.solve(lhs, rhs, sym_pos=sym_pos, lower=lower)
osp_fun = lambda lhs, rhs: osp.linalg.solve(lhs, rhs, assume_a=assume_a, lower=lower)
jsp_fun = lambda lhs, rhs: jsp.linalg.solve(lhs, rhs, assume_a=assume_a, lower=lower)

def args_maker():
a = rng(lhs_shape, dtype)
if sym_pos:
if assume_a == 'pos':
a = np.matmul(a, np.conj(T(a)))
a = np.tril(a) if lower else np.triu(a)
return [a, rng(rhs_shape, dtype)]
Expand Down

0 comments on commit 6ea9e4d

Please sign in to comment.