Merge pull request pfnet#796 from emcastillo/custom_torch_ops
Add interface for registering custom ops in `torch.ops.ppe`
Showing 6 changed files with 183 additions and 0 deletions.
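For context, a minimal sketch of how the new interface is meant to be used. The op name `mydouble` and the doubling/identity kernels are illustrative; only `OpDesc`, `register` and the `torch.ops.ppe` namespace come from this commit.

import pytorch_pfn_extras as ppe
import torch

# Forward computes y = 2 * a; backward returns dL/da = 2 * g.
# The meta functions only describe output shapes for tracing.
fwd = ppe.ops.OpDesc(
    lambda a: a * 2,
    lambda a: torch.empty_like(a),
    "(Tensor a) -> Tensor",
)
bwd = ppe.ops.OpDesc(
    lambda g, a: g * 2,
    lambda g, a: torch.empty_like(a),
    "(Tensor g, Tensor a) -> Tensor",
)
ppe.ops.register("mydouble", fwd, bwd)  # now reachable as torch.ops.ppe.mydouble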
@@ -0,0 +1 @@
from pytorch_pfn_extras.ops.register import OpDesc, register  # NOQA
@@ -0,0 +1,91 @@
from typing import Any, Callable, cast

import torch
import torch.library

# Libraries used to store the ops definitions
library = torch.library.Library("ppe", "DEF")
library_impl = torch.library.Library("ppe", "IMPL", "CompositeExplicitAutograd")
library_autograd_impl = torch.library.Library("ppe", "IMPL", "Autograd")
library_meta_impl = torch.library.Library("ppe", "IMPL", "Meta")


class OpDesc:
    """Metadata to register an op to torch.library.

    Attributes:
        op (callable): code to be executed in the forward/backward of the op.
        meta (callable): function to perform shape inference for the
            forward/backward passes.
        signature (str): arguments and return type of the function, e.g.
            ``"(Tensor a, Tensor b) -> Tensor[]"``.
    """

    def __init__(
        self,
        op: Callable[..., Any],
        meta: Callable[..., Any],
        signature: str,
    ) -> None:
        self.op = op
        self.meta = meta
        self.signature = signature


def _get_autograd(name: str) -> Callable[..., Any]:
    class RunBackward(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *args, **kwargs):  # type: ignore[no-untyped-def]
            ctx.save_for_backward(*args)
            op_h = torch._C._dispatch_find_schema_or_throw(
                f"ppe::{name}_fwd", ""
            )
            return torch._C._dispatch_call_boxed(op_h, *args, **kwargs)

        @staticmethod
        def backward(ctx, *args):  # type: ignore[no-untyped-def]
            i_args = tuple(ctx.saved_tensors)
            op_h = torch._C._dispatch_find_schema_or_throw(
                f"ppe::{name}_bwd", ""
            )
            return torch._C._dispatch_call_boxed(op_h, *(args + i_args), **{})

    return cast(Callable[..., Any], RunBackward.apply)


def register(
    name: str,
    fwd_op: OpDesc,
    bwd_op: OpDesc,
) -> None:
    """Register a custom op under ``torch.ops.ppe.name``.

    The function appears as a primitive op in the forward and backward
    ``torch.fx.Graph``s after compiling torch code with the ``aot_autograd``
    backend. Note that the backward op is called with the gradients of the
    forward outputs followed by the forward arguments: if the forward was
    called as ``fwd_op(x, y)``, the custom ``bwd_op`` needs a signature
    like ``bwd_op(grad_output, x, y)``.

    Arguments:
        name (str): name of the op, shows how it is registered in
            ``torch.ops.ppe``.
        fwd_op (ppe.ops.OpDesc): code that is executed in the forward pass.
        bwd_op (ppe.ops.OpDesc): code that is executed in the backward pass.
    """
    function_sig = f"{name}{fwd_op.signature}"
    function_fwd_sig = f"{name}_fwd{fwd_op.signature}"
    function_bwd_sig = f"{name}_bwd{bwd_op.signature}"
    for s in (function_sig, function_fwd_sig, function_bwd_sig):
        library.define(s)

    def function(*args):  # type: ignore[no-untyped-def]
        op_h = torch._C._dispatch_find_schema_or_throw(f"ppe::{name}_fwd", "")
        return torch._C._dispatch_call_boxed(op_h, *args, **{})

    library_impl.impl(name, function)
    library_impl.impl(f"{name}_fwd", fwd_op.op)
    library_impl.impl(f"{name}_bwd", bwd_op.op)
    library_meta_impl.impl(f"{name}_fwd", fwd_op.meta)
    library_meta_impl.impl(f"{name}_bwd", bwd_op.meta)
    library_autograd_impl.impl(name, _get_autograd(name))
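Continuing the hypothetical `mydouble` sketch from the commit summary above, this is roughly what that single `register` call wires up; the description is read off the code in this file, not a separately documented contract.

# library.define() creates three schemas in the "ppe" namespace:
#   ppe::mydouble(Tensor a) -> Tensor               user-facing entry point
#   ppe::mydouble_fwd(Tensor a) -> Tensor           fwd.op kernel + fwd.meta shape function
#   ppe::mydouble_bwd(Tensor g, Tensor a) -> Tensor bwd.op kernel + bwd.meta shape function
#
# The user-facing op simply redirects to mydouble_fwd, while the Autograd-key
# registration (RunBackward) saves the forward inputs and later calls
# mydouble_bwd with the output gradients followed by those saved inputs,
# which is why backward signatures take the form "(Tensor g, Tensor a)".
import torch

print(torch.ops.ppe.mydouble)      # resolvable once register() has run
print(torch.ops.ppe.mydouble_fwd)
print(torch.ops.ppe.mydouble_bwd)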
@@ -0,0 +1,9 @@
# flake8: noqa
from typing import Any, Callable

class Library:
    def __init__(self, ns: str, kind: str, dispatch_key: str = "") -> None: ...
    def impl(
        self, name: str, fn: Callable[..., Any], dispatch_key: str = ""
    ) -> None: ...
    def define(self, name: str) -> None: ...
@@ -0,0 +1,79 @@
import sys

import pytest
import pytorch_pfn_extras as ppe
import torch


def _get_function_nodes(fx_module):
    return [
        node for node in fx_module.graph.nodes if node.op == "call_function"
    ]


@pytest.mark.skipif(
    not ppe.requires("2.1.0") or sys.platform == "win32",
    reason="torch custom ops only works for PyTorch>=2.1 and linux",
)
def test_register():
    def test(a):
        return a * 2

    def test_bwd(g, a):
        return g

    def test_meta(a):
        return torch.empty_like(a)

    def test_bwd_meta(g, a):
        return torch.empty_like(a)

    fwd_op = ppe.ops.OpDesc(test, test_meta, "(Tensor a) -> Tensor")
    bwd_op = ppe.ops.OpDesc(
        test_bwd, test_bwd_meta, "(Tensor g, Tensor a) -> Tensor"
    )
    ppe.ops.register("test", fwd_op, bwd_op)

    class TestModule(torch.nn.Module):
        def forward(self, a):
            # Call the custom function
            return torch.ops.ppe.test(a)

    found_fwd_op = False
    found_bwd_op = False

    from functorch.compile import make_boxed_func
    from torch._dynamo.backends.common import aot_autograd

    # Detect the custom ops
    def fwd_compiler_fn(fx_module: torch.fx.GraphModule, _):
        nonlocal found_fwd_op
        function_nodes = _get_function_nodes(fx_module)
        assert len(function_nodes) == 1
        found_fwd_op = (
            function_nodes[0].target is torch.ops.ppe.test_fwd.default
        )
        return make_boxed_func(fx_module)

    def bwd_compiler_fn(fx_module: torch.fx.GraphModule, _):
        nonlocal found_bwd_op
        function_nodes = _get_function_nodes(fx_module)
        assert len(function_nodes) == 1
        found_bwd_op = (
            function_nodes[0].target is torch.ops.ppe.test_bwd.default
        )
        return make_boxed_func(fx_module)

    aot_backend = aot_autograd(  # type: ignore[no-untyped-call]
        fw_compiler=fwd_compiler_fn,
        bw_compiler=bwd_compiler_fn,
    )
    m = TestModule()
    torch._dynamo.reset()
    module_opt = torch.compile(m, fullgraph=True, backend=aot_backend)
    shape = [1, 16, 2048, 128]
    x = torch.ones(shape, requires_grad=True)
    y = module_opt(x)
    y.sum().backward()
    assert found_fwd_op
    assert found_bwd_op