Generalize Blockwise inplace logic
ricardoV94 committed Jan 19, 2024
1 parent 470ea60 commit bb90822
Showing 9 changed files with 231 additions and 56 deletions.
5 changes: 5 additions & 0 deletions pytensor/graph/op.py
@@ -577,6 +577,11 @@ def make_thunk(
)
return self.make_py_thunk(node, storage_map, compute_map, no_recycling)

def try_inplace_inputs(self, candidate_inputs: list[int]) -> "Op":
"""Try to return a version of self that can inplace on candidate_inputs."""
# TODO: Document this in the Create your own op docs
raise NotImplementedError()

def __str__(self):
return getattr(type(self), "__name__", super().__str__())

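For illustration only (not part of this commit), a core Op that can overwrite its first input might implement the new hook roughly as below; MyScale is a hypothetical name, and concrete implementations are added to pytensor/tensor/slinalg.py further down in this diff.

from pytensor.graph.op import Op


class MyScale(Op):
    """Hypothetical Op, used only to illustrate the try_inplace_inputs contract."""

    __props__ = ("inplace",)

    def __init__(self, inplace=False):
        self.inplace = inplace
        # destroy_map maps an output index to the input indices that output overwrites
        self.destroy_map = {0: [0]} if inplace else {}

    def try_inplace_inputs(self, candidate_inputs: list[int]) -> "Op":
        # Only input 0 can be overwritten; otherwise defer to the base class,
        # which raises NotImplementedError for the caller to catch
        if 0 in candidate_inputs:
            return type(self)(inplace=True)
        return super().try_inplace_inputs(candidate_inputs)

    # make_node / perform omitted for brevity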
11 changes: 11 additions & 0 deletions pytensor/tensor/blockwise.py
@@ -60,6 +60,7 @@ def __init__(
signature: Optional[str] = None,
name: Optional[str] = None,
gufunc_spec: Optional[tuple[str, int, int]] = None,
destroy_map=None,
**kwargs,
):
"""
@@ -94,6 +95,16 @@ def __init__(
self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
self.gufunc_spec = gufunc_spec
self._gufunc = None
if destroy_map is not None:
# TODO: Check core_op destroy_map is compatible with Blockwise destroy_map
self.destroy_map = destroy_map
if self.destroy_map != core_op.destroy_map:
# Note: Should be fine for destroy_map of Blockwise to be more extensive than that of core_op
# But we are not using that anywhere yet, so this check is fine for now
raise ValueError(
"Blockwise destroy_map must be the same as that of the core_op"
)

super().__init__(**kwargs)

def __getstate__(self):
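A minimal sketch of the new constraint, assuming Cholesky's behaviour shown later in this diff (overwrite_a=True gives the core op a destroy_map of {0: [0]}, and Blockwise infers the signature from the core op's gufunc_signature):

from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import Cholesky

core_op = Cholesky(lower=True, overwrite_a=True)  # core destroy_map: {0: [0]}

# OK: the Blockwise destroys exactly the same input as its core op
Blockwise(core_op, destroy_map={0: [0]})

# ValueError: Blockwise destroy_map must be the same as that of the core_op
Blockwise(core_op, destroy_map={0: [1]})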
12 changes: 0 additions & 12 deletions pytensor/tensor/rewriting/basic.py
@@ -135,18 +135,6 @@ def alloc_like(
return rval


def make_inplace(node, inplace_prop="inplace"):
op = getattr(node.op, "core_op", node.op)
props = op._props_dict()
if props[inplace_prop]:
return False

props[inplace_prop] = True
inplace_op = type(op)(**props)

return inplace_op.make_node(*node.inputs).outputs


def register_useless(
node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags, **kwargs
):
80 changes: 78 additions & 2 deletions pytensor/tensor/rewriting/blockwise.py
@@ -1,9 +1,11 @@
import itertools
from typing import Optional

from pytensor.compile import Supervisor
from pytensor.compile.mode import optdb
from pytensor.graph import Constant, node_rewriter
from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot
@@ -57,7 +59,7 @@ def local_useless_unbatched_blockwise(fgraph, node):
"fast_run",
"fast_compile",
"blockwise",
position=49,
position=99, # TODO: Check if this makes sense
)


@@ -199,3 +201,77 @@ def local_blockwise_alloc(fgraph, node):
assert new_outs[0].type.broadcastable == old_out_type.broadcastable
copy_stack_trace(node.outputs, new_outs)
return new_outs


@node_rewriter([Blockwise], inplace=True)
def node_blockwise_inplace(fgraph, node):
# Find inputs that are candidates for inplacing
blockwise_op = node.op

if blockwise_op.destroy_map:
# Op already has inplace
return False

core_op = blockwise_op.core_op
batch_ndim = blockwise_op.batch_ndim(node)
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]

# TODO: Refactor this code, which is also present in Elemwise Inplacer
protected_inputs = [
f.protected for f in fgraph._features if isinstance(f, Supervisor)
]
protected_inputs = list(itertools.chain.from_iterable(protected_inputs))
protected_inputs.extend(fgraph.outputs)

# TODO: Add test for the broadcastable logic (don't inplace inputs that are being broadcasted)
candidate_inputs = [
idx
for idx, inp in enumerate(node.inputs)
if (
not isinstance(inp, Constant)
and inp.type.broadcastable[:batch_ndim] == out_batch_bcast
and not fgraph.has_destroyers([inp])
and inp not in protected_inputs
)
]

if not candidate_inputs:
return None

try:
inplace_core_op = core_op.try_inplace_inputs(candidate_inputs)
except NotImplementedError:
return False

core_destroy_map = inplace_core_op.destroy_map

if not core_destroy_map:
return False

# Check Op is not trying to inplace on non-candidate inputs
for destroyed_inputs in core_destroy_map.values():
for destroyed_input in destroyed_inputs:
if destroyed_input not in candidate_inputs:
raise ValueError("core_op did not respect candidate inputs")

# Recreate core_op with inplace
inplace_blockwise_op = Blockwise(
core_op=inplace_core_op,
signature=blockwise_op.signature,
name=blockwise_op.name,
gufunc_spec=blockwise_op.gufunc_spec,
destroy_map=core_destroy_map,
)

return inplace_blockwise_op.make_node(*node.inputs).outputs


# After destroyhandler(49.5) but before we try to make elemwise things inplace (75)
blockwise_inplace = in2out(node_blockwise_inplace, name="blockwise_inplace")
optdb.register(
"blockwise_inplace",
blockwise_inplace,
"fast_run",
"inplace",
position=69.0,
)
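A rough end-to-end sketch of what the rewrite is meant to achieve (results depend on the compilation mode and on which inputs end up as inplace candidates): compiling a batched solve with mutable inputs should leave a Blockwise whose destroy_map mirrors that of its now-inplace core op.

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import solve

A = pt.tensor("A", shape=(5, 3, 3))
b = pt.tensor("b", shape=(5, 3))
x = solve(A, b, b_ndim=1)

# Mark both inputs as mutable so the Supervisor does not protect them
f = pytensor.function(
    [pytensor.In(A, mutable=True), pytensor.In(b, mutable=True)], x
)

blockwise_node = next(
    node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Blockwise)
)
# With the rewrite applied, e.g. {0: [0, 1]} if both inputs could be overwritten
print(blockwise_node.op.destroy_map)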
3 changes: 1 addition & 2 deletions pytensor/tensor/rewriting/elemwise.py
@@ -185,9 +185,8 @@ def apply(self, fgraph):
for i in range(len(node.inputs))
if i not in baseline.values()
and not isinstance(node.inputs[i], Constant)
and
# the next line should not be costly most of the time.
not fgraph.has_destroyers([node.inputs[i]])
and not fgraph.has_destroyers([node.inputs[i]])
and node.inputs[i] not in protected_inputs
]
else:
22 changes: 1 addition & 21 deletions pytensor/tensor/rewriting/linalg.py
@@ -1,16 +1,14 @@
import logging
from typing import cast

from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.nlinalg import MatrixInverse, det
from pytensor.tensor.rewriting.basic import (
make_inplace,
register_canonicalize,
register_specialize,
register_stabilize,
@@ -312,21 +310,3 @@ def local_log_prod_sqr(fgraph, node):

# TODO: have a reduction like prod and sum that simply
# returns the sign of the prod multiplication.


@node_rewriter([Cholesky], inplace=True)
def local_inplace_cholesky(fgraph, node):
return make_inplace(node, "overwrite_a")


# After destroyhandler(49.5) but before we try to make elemwise things
# inplace (75)
linalg_opt_inplace = in2out(local_inplace_cholesky, name="linalg_opt_inplace")
optdb.register(
"InplaceLinalgOpt",
linalg_opt_inplace,
"fast_run",
"inplace",
"linalg_opt_inplace",
position=69.0,
)
55 changes: 55 additions & 0 deletions pytensor/tensor/slinalg.py
@@ -141,6 +141,12 @@ def conjugate_solve_triangular(outer, inner):
else:
return [grad]

def try_inplace_inputs(self, candidate_inputs: list[int]) -> "Op":
if candidate_inputs == [0]:
return type(self)(
lower=self.lower, overwrite_a=True, on_error=self.on_error
)


def cholesky(x, lower=True, on_error="raise", overwrite_a=False):
return Blockwise(Cholesky(lower=lower, on_error=on_error, overwrite_a=overwrite_a))(
@@ -155,6 +161,8 @@ class SolveBase(Op):
"lower",
"check_finite",
"b_ndim",
"overwrite_a",
"overwrite_b",
)

def __init__(
Expand All @@ -163,6 +171,8 @@ def __init__(
lower=False,
check_finite=True,
b_ndim,
overwrite_a=False,
overwrite_b=False,
):
self.lower = lower
self.check_finite = check_finite
@@ -172,6 +182,16 @@ def __init__(
self.gufunc_signature = "(m,m),(m)->(m)"
else:
self.gufunc_signature = "(m,m),(m,n)->(m,n)"
self.overwrite_a = overwrite_a
self.overwrite_b = overwrite_b
destroy_map = {}
if self.overwrite_a and self.overwrite_b:
destroy_map[0] = [0, 1]
elif self.overwrite_a:
destroy_map[0] = [0]
elif self.overwrite_b:
destroy_map[0] = [1]
self.destroy_map = destroy_map

def perform(self, node, inputs, outputs):
pass
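For clarity (illustrative values, not assertions made by this commit): destroy_map maps an output index to the list of input indices that computing the output overwrites, so the branches in __init__ above give, for example:

from pytensor.tensor.slinalg import Solve

Solve(assume_a="gen", b_ndim=1, overwrite_a=True).destroy_map                    # {0: [0]}
Solve(assume_a="gen", b_ndim=1, overwrite_b=True).destroy_map                    # {0: [1]}
Solve(assume_a="gen", b_ndim=1, overwrite_a=True, overwrite_b=True).destroy_map  # {0: [0, 1]}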
@@ -245,7 +265,16 @@ def _default_b_ndim(b, b_ndim):


class CholeskySolve(SolveBase):
__props__ = (
"lower",
"check_finite",
"b_ndim",
"overwrite_b",
)

def __init__(self, **kwargs):
if kwargs.get("overwrite_a", False):
raise ValueError("overwrite_a is not supported for CholeskySolve")
kwargs.setdefault("lower", True)
super().__init__(**kwargs)

@@ -260,8 +289,15 @@ def perform(self, node, inputs, output_storage):
output_storage[0][0] = rval

def L_op(self, *args, **kwargs):
# TODO: Base impl should work, let's try it
raise NotImplementedError()

def try_inplace_inputs(self, candidate_inputs: list[int]) -> "Op":
if 1 in candidate_inputs:
new_props = self._props_dict()
new_props["overwrite_b"] = True
return type(self)(**new_props)
# Only b (input 1) can be overwritten; signal failure as the base method does
raise NotImplementedError()


def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: Optional[int] = None):
"""Solve the linear equations A x = b, given the Cholesky factorization of A.
@@ -296,9 +332,12 @@ class SolveTriangular(SolveBase):
"lower",
"check_finite",
"b_ndim",
"overwrite_b",
)

def __init__(self, *, trans=0, unit_diagonal=False, **kwargs):
if kwargs.get("overwrite_a", False):
raise ValueError("overwrite_a is not supported for SolveTriangular")
super().__init__(**kwargs)
self.trans = trans
self.unit_diagonal = unit_diagonal
@@ -324,6 +363,12 @@ def L_op(self, inputs, outputs, output_gradients):

return res

def try_inplace_inputs(self, candidate_inputs: list[int]) -> "Op":
if 1 in candidate_inputs:
new_props = self._props_dict()
new_props["overwrite_b"] = True
return type(self)(**new_props)
# Only b (input 1) can be overwritten
raise NotImplementedError()


def solve_triangular(
a: TensorVariable,
@@ -383,6 +428,8 @@ class Solve(SolveBase):
"lower",
"check_finite",
"b_ndim",
"overwrite_a",
"overwrite_b",
)

def __init__(self, *, assume_a="gen", **kwargs):
@@ -402,6 +449,14 @@ def perform(self, node, inputs, outputs):
assume_a=self.assume_a,
)

def try_inplace_inputs(self, candidate_inputs: list[int]) -> "Op":
new_props = self._props_dict()
if 0 in candidate_inputs:
new_props["overwrite_a"] = True
if 1 in candidate_inputs:
new_props["overwrite_b"] = True
return type(self)(**new_props)


def solve(
a,
39 changes: 22 additions & 17 deletions tests/tensor/rewriting/test_linalg.py
@@ -312,33 +312,38 @@ def test_invalid_batched_a(self):
config.mode == "FAST_COMPILE",
reason="inplace rewrites disabled when mode is FAST_COMPILE",
)
def test_local_inplace_cholesky():
X = matrix("X")
@pytest.mark.parametrize("is_batched", (False, True))
def test_local_inplace_cholesky(is_batched):
shape = (5, None, None) if is_batched else (None, None)
X = tensor("X", shape=shape)
L = cholesky(X, overwrite_a=False, lower=True)
f = function([pytensor.In(X, mutable=True)], L)

assert not L.owner.op.core_op.overwrite_a

nodes = f.maker.fgraph.toposort()
for node in nodes:
if isinstance(node, Cholesky):
assert node.overwrite_a
break
if is_batched:
[cholesky_op] = [
node.op.core_op
for node in f.maker.fgraph.apply_nodes
if isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Cholesky)
]
else:
[cholesky_op] = [
node.op
for node in f.maker.fgraph.apply_nodes
if isinstance(node.op, Cholesky)
]
assert cholesky_op.overwrite_a

X_val = np.random.normal(size=(10, 10)).astype(config.floatX)
X_val_in = X_val @ X_val.T
if is_batched:
X_val_in = np.broadcast_to(X_val_in, (5, *X_val_in.shape)).copy()
X_val_in_copy = X_val_in.copy()

f(X_val_in)

assert_allclose(
X_val_in[np.triu_indices_from(X_val_in, k=1)],
0.0,
atol=1e-4 if config.floatX == "float32" else 1e-8,
rtol=1e-4 if config.floatX == "float32" else 1e-8,
)
assert_allclose(
X_val_in @ X_val_in.T,
X_val_in_copy,
atol=1e-4 if config.floatX == "float32" else 1e-8,
rtol=1e-4 if config.floatX == "float32" else 1e-8,
X_val_in,
np.linalg.cholesky(X_val_in_copy),
)
