Skip to content

Commit

Permalink
Implement inplace for Blockwise Ops
Browse files Browse the repository at this point in the history
Co-authored-by: Ricardo Vieira <28983449+ricardov94@users.noreply.github.com>
  • Loading branch information
jessegrabowski and ricardoV94 committed Aug 30, 2024
1 parent 1e9ff57 commit 48721f4
Show file tree
Hide file tree
Showing 5 changed files with 338 additions and 39 deletions.
6 changes: 6 additions & 0 deletions pytensor/graph/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,12 @@ def make_thunk(
)
return self.make_py_thunk(node, storage_map, compute_map, no_recycling)

def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
"""Try to return a version of self that tries to inplace in as many as `allowed_inplace_inputs`."""
# TODO: Document this in the Create your own op docs
# By default, do nothing
return self

def __str__(self):
return getattr(type(self), "__name__", super().__str__())

Expand Down
10 changes: 10 additions & 0 deletions pytensor/tensor/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
signature: str | None = None,
name: str | None = None,
gufunc_spec: tuple[str, int, int] | None = None,
destroy_map=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -79,6 +80,15 @@ def __init__(
self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
self.gufunc_spec = gufunc_spec
self._gufunc = None
if destroy_map is not None:
self.destroy_map = destroy_map
if self.destroy_map != core_op.destroy_map:
# Note: Should be fine for destroy_map of Blockwise to be more extensive than that of core_op
# But we are not using that anywhere yet, so this check is fine for now
raise ValueError(
"Blockwise destroy_map must be the same as that of the core_op"
)

super().__init__(**kwargs)

def __getstate__(self):
Expand Down
82 changes: 80 additions & 2 deletions pytensor/tensor/rewriting/blockwise.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import itertools

from pytensor.compile import Supervisor
from pytensor.compile.mode import optdb
from pytensor.graph import Constant, node_rewriter
from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot
Expand Down Expand Up @@ -50,13 +53,14 @@ def local_useless_unbatched_blockwise(fgraph, node):


# We register this rewrite late, so that other rewrites need only target Blockwise Ops
# We do it after position>=70 so that Blockwise inplace rewrites will work also on useless Blockwise Ops
optdb.register(
"local_useless_unbatched_blockwise",
out2in(local_useless_unbatched_blockwise, ignore_newtrees=True),
"fast_run",
"fast_compile",
"blockwise",
position=49,
position=60,
)


Expand Down Expand Up @@ -225,3 +229,77 @@ def local_blockwise_reshape(fgraph, node):
new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)])
copy_stack_trace(node.outputs[0], new_out)
return [new_out]


@node_rewriter(tracks=[Blockwise], inplace=True)
def blockwise_inplace(fgraph, node):
blockwise_op = node.op

if blockwise_op.destroy_map:
# Op already has inplace
return

# Find out valid inputs for inplacing
batch_ndim = blockwise_op.batch_ndim(node)
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]

protected_inputs = [
f.protected for f in fgraph._features if isinstance(f, Supervisor)
]
protected_inputs = list(itertools.chain.from_iterable(protected_inputs))
protected_inputs.extend(fgraph.outputs)
allowed_inplace_inputs = [
idx
for idx, inp in enumerate(node.inputs)
if
(
# Constants would need to be recreated every time if inplaced
not isinstance(inp, Constant)
# We can only inplace on inputs that are not being broadcasted
# As those are reused across iterations of Blockwise
and node.inputs[idx].type.broadcastable[:batch_ndim] == out_batch_bcast
# Inputs that are marked as protected or destroyed can't be inplaced
and not fgraph.has_destroyers([inp])
and inp not in protected_inputs
)
]

if not allowed_inplace_inputs:
return None

inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
allowed_inplace_inputs=allowed_inplace_inputs
)

if not inplace_core_op.destroy_map:
return None

# Check Op is not trying to inplace on non-candidate inputs
for destroyed_inputs in inplace_core_op.destroy_map.values():
for destroyed_input in destroyed_inputs:
if destroyed_input not in allowed_inplace_inputs:
raise ValueError(
"Op destroy_map does not respect allowed_inplace_inputs"
)

# Recreate core_op with inplace
inplace_blockwise_op = Blockwise(
core_op=inplace_core_op,
signature=blockwise_op.signature,
name=blockwise_op.name,
gufunc_spec=blockwise_op.gufunc_spec,
destroy_map=inplace_core_op.destroy_map,
)

out = inplace_blockwise_op.make_node(*node.inputs).outputs
copy_stack_trace(node.outputs, out)
return out


optdb.register(
"blockwise_inplace",
in2out(blockwise_inplace),
"fast_run",
"inplace",
position=50.1,
)
Loading

0 comments on commit 48721f4

Please sign in to comment.