From 0a37de57066da9bf861f1074d2005e2a93aab49d Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Sat, 6 Jan 2024 22:05:31 +0100
Subject: [PATCH] Inplace Blockwise and core versions of Cholesky and Solve Ops.

Co-authored-by: Ricardo Vieira <28983449+ricardov94@users.noreply.github.com>
---
 pytensor/graph/op.py                   |   6 +
 pytensor/tensor/blockwise.py           |  10 ++
 pytensor/tensor/rewriting/blockwise.py |  82 +++++++++-
 pytensor/tensor/slinalg.py             | 203 ++++++++++++++++++++-----
 tests/tensor/test_blockwise.py         | 117 +++++++++++++-
 tests/tensor/test_slinalg.py           |   7 +-
 6 files changed, 384 insertions(+), 41 deletions(-)

diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py
index 160a65dd7a..684add6308 100644
--- a/pytensor/graph/op.py
+++ b/pytensor/graph/op.py
@@ -583,6 +583,12 @@ def make_thunk(
         )
         return self.make_py_thunk(node, storage_map, compute_map, no_recycling)
 
+    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
+        """Try to return a version of self that works inplace on as many of `allowed_inplace_inputs` as possible."""
+        # TODO: Document this in the Create your own Op docs
+        # By default, do nothing
+        return self
+
     def __str__(self):
         return getattr(type(self), "__name__", super().__str__())
 
diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py
index 08956a0534..8c54e53a98 100644
--- a/pytensor/tensor/blockwise.py
+++ b/pytensor/tensor/blockwise.py
@@ -45,6 +45,7 @@ def __init__(
         signature: str | None = None,
         name: str | None = None,
         gufunc_spec: tuple[str, int, int] | None = None,
+        destroy_map=None,
         **kwargs,
     ):
         """
@@ -79,6 +80,15 @@ def __init__(
         self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
         self.gufunc_spec = gufunc_spec
         self._gufunc = None
+        if destroy_map is not None:
+            self.destroy_map = destroy_map
+        if self.destroy_map != core_op.destroy_map:
+            # Note: Should be fine for destroy_map of Blockwise to be more extensive than that of core_op
+            # But we are not using that anywhere yet, so this check is fine for now
+            raise ValueError(
+                "Blockwise destroy_map must be the same as that of the core_op"
+            )
+
         super().__init__(**kwargs)
 
     def __getstate__(self):
diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py
index 7220824c58..9f3ce9f2e6 100644
--- a/pytensor/tensor/rewriting/blockwise.py
+++ b/pytensor/tensor/rewriting/blockwise.py
@@ -1,7 +1,10 @@
+import itertools
+
+from pytensor.compile import Supervisor
 from pytensor.compile.mode import optdb
 from pytensor.graph import Constant, node_rewriter
 from pytensor.graph.replace import vectorize_node
-from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
+from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
 from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.math import Dot
@@ -50,13 +53,14 @@ def local_useless_unbatched_blockwise(fgraph, node):
 
 
 # We register this rewrite late, so that other rewrites need only target Blockwise Ops
+# We register it at position>=60 so that the Blockwise inplace rewrites can also act on useless Blockwise Ops
 optdb.register(
     "local_useless_unbatched_blockwise",
     out2in(local_useless_unbatched_blockwise, ignore_newtrees=True),
     "fast_run",
     "fast_compile",
     "blockwise",
-    position=49,
+    position=60,
 )
 
 
@@ -225,3 +229,77 @@ def local_blockwise_reshape(fgraph, node):
     new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)])
     copy_stack_trace(node.outputs[0], new_out)
     return [new_out]
+
+
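+# The rewrite below asks the core_op of each Blockwise for an inplace variant via
+# `Op.inplace_on_inputs` (added in pytensor/graph/op.py above) and, when one is
+# returned, rebuilds the Blockwise with that core_op's destroy_map so the selected
+# input buffers can be overwritten by the outputs.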
+@node_rewriter(tracks=[Blockwise], inplace=True)
+def blockwise_inplace(fgraph, node):
+    blockwise_op = node.op
+
+    if blockwise_op.destroy_map:
+        # Op already has inplace
+        return
+
+    # Find out valid inputs for inplacing
+    batch_ndim = blockwise_op.batch_ndim(node)
+    out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
+
+    protected_inputs = [
+        f.protected for f in fgraph._features if isinstance(f, Supervisor)
+    ]
+    protected_inputs = list(itertools.chain.from_iterable(protected_inputs))
+    protected_inputs.extend(fgraph.outputs)
+    allowed_inplace_inputs = [
+        idx
+        for idx, inp in enumerate(node.inputs)
+        if
+        (
+            # Constants would need to be recreated every time if inplaced
+            not isinstance(inp, Constant)
+            # We can only inplace on inputs that are not being broadcasted
+            # As those are reused across iterations of Blockwise
+            and node.inputs[idx].type.broadcastable[:batch_ndim] == out_batch_bcast
+            # Inputs that are marked as protected or destroyed can't be inplaced
+            and not fgraph.has_destroyers([inp])
+            and inp not in protected_inputs
+        )
+    ]
+
+    if not allowed_inplace_inputs:
+        return None
+
+    inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
+        allowed_inplace_inputs=allowed_inplace_inputs
+    )
+
+    if not inplace_core_op.destroy_map:
+        return None
+
+    # Check Op is not trying to inplace on non-candidate inputs
+    for destroyed_inputs in inplace_core_op.destroy_map.values():
+        for destroyed_input in destroyed_inputs:
+            if destroyed_input not in allowed_inplace_inputs:
+                raise ValueError(
+                    "Op destroy_map does not respect allowed_inplace_inputs"
+                )
+
+    # Recreate core_op with inplace
+    inplace_blockwise_op = Blockwise(
+        core_op=inplace_core_op,
+        signature=blockwise_op.signature,
+        name=blockwise_op.name,
+        gufunc_spec=blockwise_op.gufunc_spec,
+        destroy_map=inplace_core_op.destroy_map,
+    )
+
+    out = inplace_blockwise_op.make_node(*node.inputs).outputs
+    copy_stack_trace(node.outputs, out)
+    return out
+
+
+optdb.register(
+    "blockwise_inplace",
+    in2out(blockwise_inplace),
+    "fast_run",
+    "inplace",
+    position=50.1,
+)
diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index db8303b2d8..38c7fecb35 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -28,57 +28,68 @@ class Cholesky(Op):
-    """
-    Return a triangular matrix square root of positive semi-definite `x`.
-
-    L = cholesky(X, lower=True) implies dot(L, L.T) == X.
-
-    Parameters
-    ----------
-    lower : bool, default=True
-        Whether to return the lower or upper cholesky factor
-    on_error : ['raise', 'nan']
-        If on_error is set to 'raise', this Op will raise a
-        `scipy.linalg.LinAlgError` if the matrix is not positive definite.
-        If on_error is set to 'nan', it will return a matrix containing
-        nans instead.
-    """
-
-    # TODO: inplace
-    # TODO: for specific dtypes
     # TODO: LAPACK wrapper with in-place behavior, for solve also
-    __props__ = ("lower", "destructive", "on_error")
+    __props__ = ("lower", "check_finite", "on_error", "overwrite_a")
 
     gufunc_signature = "(m,m)->(m,m)"
 
-    def __init__(self, *, lower=True, check_finite=True, on_error="raise"):
+    def __init__(
+        self,
+        *,
+        lower: bool = True,
+        check_finite: bool = True,
+        on_error: Literal["raise"] | Literal["nan"] = "raise",
+        overwrite_a: bool = False,
+    ):
         self.lower = lower
-        self.destructive = False
         self.check_finite = check_finite
         if on_error not in ("raise", "nan"):
             raise ValueError('on_error must be one of "raise" or ""nan"')
         self.on_error = on_error
+        self.overwrite_a = overwrite_a
+
+        if self.overwrite_a:
+            self.destroy_map = {0: [0]}
 
     def infer_shape(self, fgraph, node, shapes):
         return [shapes[0]]
 
     def make_node(self, x):
         x = as_tensor_variable(x)
-        assert x.ndim == 2
-        return Apply(self, [x], [x.type()])
+        if x.type.ndim != 2:
+            raise TypeError(
+                f"Cholesky only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
+            )
+        # Call scipy to find output dtype
+        dtype = scipy.linalg.cholesky(np.eye(1, dtype=x.type.dtype)).dtype
+        return Apply(self, [x], [tensor(shape=x.type.shape, dtype=dtype)])
 
     def perform(self, node, inputs, outputs):
-        x = inputs[0]
-        z = outputs[0]
+        [x] = inputs
+        [out] = outputs
         try:
-            z[0] = scipy.linalg.cholesky(
-                x, lower=self.lower, check_finite=self.check_finite
-            ).astype(x.dtype)
+            # Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
+            # If we have a `C_CONTIGUOUS` array we transpose to benefit from it
+            if self.overwrite_a and x.flags["C_CONTIGUOUS"]:
+                out[0] = scipy.linalg.cholesky(
+                    x.T,
+                    lower=not self.lower,
+                    check_finite=self.check_finite,
+                    overwrite_a=True,
+                ).T
+            else:
+                out[0] = scipy.linalg.cholesky(
+                    x,
+                    lower=self.lower,
+                    check_finite=self.check_finite,
+                    overwrite_a=self.overwrite_a,
+                )
+
         except scipy.linalg.LinAlgError:
             if self.on_error == "raise":
                 raise
             else:
-                z[0] = (np.zeros(x.shape) * np.nan).astype(x.dtype)
+                out[0] = np.full(x.shape, np.nan, dtype=node.outputs[0].type.dtype)
 
     def L_op(self, inputs, outputs, gradients):
         """
@@ -131,11 +142,66 @@ def conjugate_solve_triangular(outer, inner):
         else:
             return [grad]
 
+    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
+        if not allowed_inplace_inputs:
+            return self
+        new_props = self._props_dict()  # type: ignore
+        new_props["overwrite_a"] = True
+        return type(self)(**new_props)
 
-def cholesky(x, lower=True, on_error="raise", check_finite=False):
-    return Blockwise(
-        Cholesky(lower=lower, on_error=on_error, check_finite=check_finite)
-    )(x)
+
+def cholesky(
+    x: "TensorLike",
+    lower: bool = True,
+    *,
+    check_finite: bool = False,
+    overwrite_a: bool = False,
+    on_error: Literal["raise", "nan"] = "raise",
+):
+    """
+    Return a triangular matrix square root of positive semi-definite `x`.
+
+    L = cholesky(X, lower=True) implies dot(L, L.T) == X.
+
+    Parameters
+    ----------
+    x: tensor_like
+    lower : bool, default=True
+        Whether to return the lower or upper cholesky factor
+    check_finite : bool, default=False
+        Whether to check that the input matrix contains only finite numbers.
+    overwrite_a: bool, ignored
+        Whether to use the same memory for the output as `a`. This argument is ignored, and is present here only
+        for consistency with scipy.linalg.cholesky.
+    on_error : ['raise', 'nan']
+        If on_error is set to 'raise', this Op will raise a `scipy.linalg.LinAlgError` if the matrix is not positive definite.
+        If on_error is set to 'nan', it will return a matrix containing nans instead.
+
+    Returns
+    -------
+    TensorVariable
+        Lower or upper triangular Cholesky factor of `x`
+
+    Example
+    -------
+    .. testcode::
+
+        import pytensor
+        import pytensor.tensor as pt
+        import numpy as np
+
+        x = pt.tensor('x', shape=(5, 5), dtype='float64')
+        L = pt.linalg.cholesky(x)
+
+        f = pytensor.function([x], L)
+        x_value = np.random.normal(size=(5, 5))
+        x_value = x_value @ x_value.T  # Ensures x is positive definite
+        L_value = f(x_value)
+        assert np.allclose(L_value @ L_value.T, x_value)
+
+    """
+
+    return Blockwise(Cholesky(lower=lower, on_error=on_error, check_finite=check_finite))(x)
 
 
 class SolveBase(Op):
@@ -145,6 +211,8 @@ class SolveBase(Op):
         "lower",
         "check_finite",
         "b_ndim",
+        "overwrite_a",
+        "overwrite_b",
     )
 
     def __init__(
@@ -153,6 +221,8 @@ def __init__(
         lower=False,
         check_finite=True,
         b_ndim,
+        overwrite_a=False,
+        overwrite_b=False,
     ):
         self.lower = lower
         self.check_finite = check_finite
@@ -162,9 +232,25 @@ def __init__(
             self.gufunc_signature = "(m,m),(m)->(m)"
         else:
             self.gufunc_signature = "(m,m),(m,n)->(m,n)"
+        self.overwrite_a = overwrite_a
+        self.overwrite_b = overwrite_b
+        destroy_map = {}
+        if self.overwrite_a and self.overwrite_b:
+            # An output destroying two inputs is not yet supported
+            # destroy_map[0] = [0, 1]
+            raise NotImplementedError(
+                "It's not yet possible to overwrite_a and overwrite_b simultaneously"
+            )
+        elif self.overwrite_a:
+            destroy_map[0] = [0]
+        elif self.overwrite_b:
+            destroy_map[0] = [1]
+        self.destroy_map = destroy_map
 
     def perform(self, node, inputs, outputs):
-        pass
+        raise NotImplementedError(
+            "SolveBase should be subclassed with a perform method"
+        )
 
     def make_node(self, A, b):
         A = as_tensor_variable(A)
@@ -235,7 +321,16 @@ def _default_b_ndim(b, b_ndim):
 
 
 class CholeskySolve(SolveBase):
+    __props__ = (
+        "lower",
+        "check_finite",
+        "b_ndim",
+        "overwrite_b",
+    )
+
     def __init__(self, **kwargs):
+        if kwargs.get("overwrite_a", False):
+            raise ValueError("overwrite_a is not supported for CholeskySolve")
         kwargs.setdefault("lower", True)
         super().__init__(**kwargs)
 
@@ -245,13 +340,23 @@ def perform(self, node, inputs, output_storage):
             (C, self.lower),
             b,
             check_finite=self.check_finite,
+            overwrite_b=self.overwrite_b,
         )
 
         output_storage[0][0] = rval
 
     def L_op(self, *args, **kwargs):
+        # TODO: Base impl should work, let's try it
         raise NotImplementedError()
 
+    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
+        if 1 in allowed_inplace_inputs:
+            new_props = self._props_dict()  # type: ignore
+            new_props["overwrite_b"] = True
+            return type(self)(**new_props)
+        else:
+            return self
+
 
 def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: int | None = None):
     """Solve the linear equations A x = b, given the Cholesky factorization of A.
@@ -286,9 +391,12 @@ class SolveTriangular(SolveBase):
         "lower",
         "check_finite",
         "b_ndim",
+        "overwrite_b",
     )
 
     def __init__(self, *, trans=0, unit_diagonal=False, **kwargs):
+        if kwargs.get("overwrite_a", False):
+            raise ValueError("overwrite_a is not supported for SolveTriangular")
         super().__init__(**kwargs)
         self.trans = trans
         self.unit_diagonal = unit_diagonal
@@ -302,6 +410,7 @@ def perform(self, node, inputs, outputs):
             trans=self.trans,
             unit_diagonal=self.unit_diagonal,
             check_finite=self.check_finite,
+            overwrite_b=self.overwrite_b,
         )
 
     def L_op(self, inputs, outputs, output_gradients):
@@ -314,6 +423,14 @@ def L_op(self, inputs, outputs, output_gradients):
 
         return res
 
+    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
+        if 1 in allowed_inplace_inputs:
+            new_props = self._props_dict()  # type: ignore
+            new_props["overwrite_b"] = True
+            return type(self)(**new_props)
+        else:
+            return self
+
 
 def solve_triangular(
     a: TensorVariable,
@@ -374,6 +491,8 @@ class Solve(SolveBase):
         "lower",
         "check_finite",
         "b_ndim",
+        "overwrite_a",
+        "overwrite_b",
     )
 
     def __init__(self, *, assume_a="gen", **kwargs):
@@ -391,8 +510,24 @@ def perform(self, node, inputs, outputs):
             lower=self.lower,
             check_finite=self.check_finite,
             assume_a=self.assume_a,
+            overwrite_a=self.overwrite_a,
+            overwrite_b=self.overwrite_b,
         )
 
+    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
+        if not allowed_inplace_inputs:
+            return self
+        new_props = self._props_dict()  # type: ignore
+        # PyTensor doesn't allow an output to destroy two inputs yet
+        # new_props["overwrite_a"] = 0 in allowed_inplace_inputs
+        # new_props["overwrite_b"] = 1 in allowed_inplace_inputs
+        if 1 in allowed_inplace_inputs:
+            # Give preference to overwrite_b
+            new_props["overwrite_b"] = True
+        else:  # allowed inputs == [0]
+            new_props["overwrite_a"] = True
+        return type(self)(**new_props)
+
 
 def solve(
     a,
diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py
index 43f9b77f4f..f6783cf945 100644
--- a/tests/tensor/test_blockwise.py
+++ b/tests/tensor/test_blockwise.py
@@ -3,10 +3,11 @@
 
 import numpy as np
 import pytest
+import scipy.linalg
 
 import pytensor
-from pytensor import config, function
-from pytensor.compile import get_mode
+from pytensor import In, config, function
+from pytensor.compile import get_default_mode, get_mode
 from pytensor.gradient import grad
 from pytensor.graph import Apply, Op
 from pytensor.graph.replace import vectorize_node
@@ -15,7 +16,15 @@
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
 from pytensor.tensor.nlinalg import MatrixInverse
 from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
-from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
+from pytensor.tensor.slinalg import (
+    Cholesky,
+    Solve,
+    SolveBase,
+    cho_solve,
+    cholesky,
+    solve,
+    solve_triangular,
+)
 from pytensor.tensor.utils import _parse_gufunc_signature
@@ -398,3 +407,105 @@ def test_cop_with_params():
 
     with pytest.raises(AssertionError):
         fn(np.zeros((5, 3, 2)) - 1)
+
+
+@pytest.mark.skipif(
+    config.mode == "FAST_COMPILE",
+    reason="inplace rewrites disabled when mode is FAST_COMPILE",
+)
+class TestInplace:
+    @pytest.mark.parametrize("is_batched", (False, True))
+    def test_cholesky(self, is_batched):
+        X = tensor("X", shape=(5, None, None) if is_batched else (None, None))
+        L = cholesky(X, lower=True)
+        f = function([In(X, mutable=True)], L)
+
+        assert not L.owner.op.core_op.destroy_map
+
+        if is_batched:
+            [cholesky_op] = [
+                node.op.core_op
+                for node in f.maker.fgraph.apply_nodes
+                if isinstance(node.op, Blockwise)
+                and isinstance(node.op.core_op, Cholesky)
+            ]
+        else:
+            [cholesky_op] = [
+                node.op
+                for node in f.maker.fgraph.apply_nodes
+                if isinstance(node.op, Cholesky)
+            ]
+        assert cholesky_op.destroy_map == {0: [0]}
+
+        rng = np.random.default_rng(441 + is_batched)
+        X_val = rng.normal(size=(10, 10)).astype(config.floatX)
+        X_val_in = X_val @ X_val.T
+        if is_batched:
+            X_val_in = np.broadcast_to(X_val_in, (5, *X_val_in.shape)).copy()
+        X_val_in_copy = X_val_in.copy()
+
+        f(X_val_in)
+
+        np.testing.assert_allclose(
+            X_val_in,
+            np.linalg.cholesky(X_val_in_copy),
+            atol=1e-5 if config.floatX == "float32" else 0,
+        )
+
+    @pytest.mark.parametrize("batched_A", (False, True))
+    @pytest.mark.parametrize("batched_b", (False, True))
+    @pytest.mark.parametrize("solve_fn", (solve, solve_triangular, cho_solve))
+    def test_solve(self, solve_fn, batched_A, batched_b):
+        A = tensor("A", shape=(5, 3, 3) if batched_A else (3, 3))
+        b = tensor("b", shape=(5, 3) if batched_b else (3,))
+        if solve_fn == cho_solve:
+            # Special signature for cho_solve
+            x = solve_fn((A, True), b, b_ndim=1)
+        else:
+            x = solve_fn(A, b, b_ndim=1)
+
+        mode = get_default_mode().excluding("batched_vector_b_solve_to_matrix_b_solve")
+        fn = function([In(A, mutable=True), In(b, mutable=True)], x, mode=mode)
+
+        op = fn.maker.fgraph.outputs[0].owner.op
+        if batched_A or batched_b:
+            assert isinstance(op, Blockwise) and isinstance(op.core_op, SolveBase)
+            if batched_A and not batched_b:
+                if solve_fn == solve:
+                    assert op.destroy_map == {0: [0]}
+                else:
+                    # SolveTriangular does not destroy A
+                    assert op.destroy_map == {}
+            else:
+                assert op.destroy_map == {0: [1]}
+        else:
+            assert isinstance(op, SolveBase)
+            assert op.destroy_map == {0: [1]}
+
+        # We test with an F_CONTIGUOUS (core) A as only that will be destroyed by scipy
+        rng = np.random.default_rng(
+            487 + batched_A + 2 * batched_b + sum(map(ord, solve_fn.__name__))
+        )
+        A_val = np.swapaxes(rng.normal(size=A.type.shape).astype(A.type.dtype), -1, -2)
+        b_val = np.random.normal(size=b.type.shape).astype(b.type.dtype)
+        A_val_copy = A_val.copy()
+        b_val_copy = b_val.copy()
+        out = fn(A_val, b_val)
+
+        if solve_fn == cho_solve:
+
+            def core_scipy_fn(A, b):
+                return scipy.linalg.cho_solve((A, True), b)
+
+        else:
+            core_scipy_fn = getattr(scipy.linalg, solve_fn.__name__)
+        expected_out = np.vectorize(core_scipy_fn, signature="(m,m),(m)->(m)")(
+            A_val_copy, b_val_copy
+        )
+        np.testing.assert_allclose(
+            out, expected_out, atol=1e-6 if config.floatX == "float32" else 0
+        )
+
+        # Confirm input was destroyed
+        assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0])
+        assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1])
diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py
index e468b56e84..28a0210278 100644
--- a/tests/tensor/test_slinalg.py
+++ b/tests/tensor/test_slinalg.py
@@ -197,7 +197,10 @@ def test__repr__(self):
         A = matrix()
         b = matrix()
         y = SolveBase(b_ndim=2)(A, b)
-        assert y.__repr__() == "SolveBase{lower=False, check_finite=True, b_ndim=2}.0"
+        assert (
+            y.__repr__()
+            == "SolveBase{lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
+        )
 
 
 class TestSolve(utt.InferShapeTester):
@@ -361,7 +364,7 @@ def setup_method(self):
     def test_repr(self):
         assert (
             repr(CholeskySolve(lower=True, b_ndim=1))
-            == "CholeskySolve(lower=True,check_finite=True,b_ndim=1)"
+            == "CholeskySolve(lower=True,check_finite=True,b_ndim=1,overwrite_b=False)"
         )
 
     def test_infer_shape(self):