diff --git a/CHANGELOG.md b/CHANGELOG.md index f6a7d6743ad3..2c4fd48cd159 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * {func}`jax.tree_structure` is deprecated in favor of {func}`jax.tree_util.tree_structure` * {func}`jax.tree_transpose` is deprecated in favor of {func}`jax.tree_util.tree_transpose` * {func}`jax.tree_unflatten` is deprecated in favor of {func}`jax.tree_util.tree_unflatten` + * The `sym_pos` argument of {func}`jax.scipy.linalg.solve` is deprecated in favor of `assume_a='pos'`, + following a similar deprecation in {func}`scipy.linalg.solve`. ## jaxlib 0.3.15 (Unreleased) diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index 24e405816b5d..8ef4326f0b2e 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -205,9 +205,9 @@ def qr(a, overwrite_a=False, lwork=None, mode="full", pivoting=False, return _qr(a, mode, pivoting) -@partial(jit, static_argnames=('sym_pos', 'lower')) -def _solve(a, b, sym_pos, lower): - if not sym_pos: +@partial(jit, static_argnames=('assume_a', 'lower')) +def _solve(a, b, assume_a, lower): + if assume_a != 'pos': return np_linalg.solve(a, b) a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b)) @@ -232,9 +232,18 @@ def _solve(a, b, sym_pos, lower): @_wraps(scipy.linalg.solve, lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'overwrite_b', 'debug', 'check_finite')) def solve(a, b, sym_pos=False, lower=False, overwrite_a=False, overwrite_b=False, - debug=False, check_finite=True): + debug=False, check_finite=True, assume_a='gen'): + # TODO(jakevdp) remove sym_pos argument after October 2022 del overwrite_a, overwrite_b, debug, check_finite - return _solve(a, b, sym_pos, lower) + valid_assume_a = ['gen', 'sym', 'her', 'pos'] + if assume_a not in valid_assume_a: + raise ValueError("Expected assume_a to be one of {valid_assume_a}; got {assume_a!r}") + if sym_pos: + warnings.warn("The sym_pos argument to solve() is deprecated and will be removed " + "in a future JAX release. Use assume_a='pos' instead.", + category=FutureWarning, stacklevel=2) + assume_a = 'pos' + return _solve(a, b, assume_a, lower) @partial(jit, static_argnames=('trans', 'lower', 'unit_diagonal')) def _solve_triangular(a, b, trans, lower, unit_diagonal): diff --git a/jax/_src/scipy/sparse/linalg.py b/jax/_src/scipy/sparse/linalg.py index 275a88fa9108..e9fddd05683c 100644 --- a/jax/_src/scipy/sparse/linalg.py +++ b/jax/_src/scipy/sparse/linalg.py @@ -512,7 +512,7 @@ def _lstsq(a, b): # faster than jsp.linalg.lstsq a2 = _dot(a.T.conj(), a) b2 = _dot(a.T.conj(), b) - return jsp.linalg.solve(a2, b2, sym_pos=True) + return jsp.linalg.solve(a2, b2, assume_a='pos') def _gmres_batched(A, b, x0, unit_residual, residual_norm, ptol, restart, M): diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 909b08c448e4..0368c315758c 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -1153,31 +1153,31 @@ def args_maker(): @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": - "_lhs={}_rhs={}_sym_pos={}_lower={}".format( + "_lhs={}_rhs={}_assume_a={}_lower={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), - sym_pos, lower), + assume_a, lower), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, - "sym_pos": sym_pos, "lower": lower} + "assume_a": assume_a, "lower": lower} for lhs_shape, rhs_shape in [ ((1, 1), (1, 1)), ((4, 4), (4,)), ((8, 8), (8, 4)), ] - for sym_pos, lower in [ - (False, False), - (True, False), - (True, True), + for assume_a, lower in [ + ('gen', False), + ('pos', False), + ('pos', True), ] for dtype in float_types + complex_types)) - def testSolve(self, lhs_shape, rhs_shape, dtype, sym_pos, lower): + def testSolve(self, lhs_shape, rhs_shape, dtype, assume_a, lower): rng = jtu.rand_default(self.rng()) - osp_fun = lambda lhs, rhs: osp.linalg.solve(lhs, rhs, sym_pos=sym_pos, lower=lower) - jsp_fun = lambda lhs, rhs: jsp.linalg.solve(lhs, rhs, sym_pos=sym_pos, lower=lower) + osp_fun = lambda lhs, rhs: osp.linalg.solve(lhs, rhs, assume_a=assume_a, lower=lower) + jsp_fun = lambda lhs, rhs: jsp.linalg.solve(lhs, rhs, assume_a=assume_a, lower=lower) def args_maker(): a = rng(lhs_shape, dtype) - if sym_pos: + if assume_a == 'pos': a = np.matmul(a, np.conj(T(a))) a = np.tril(a) if lower else np.triu(a) return [a, rng(rhs_shape, dtype)]