diff --git a/pytorch_pfn_extras/__init__.py b/pytorch_pfn_extras/__init__.py
index 30379c702..ac080c454 100644
--- a/pytorch_pfn_extras/__init__.py
+++ b/pytorch_pfn_extras/__init__.py
@@ -28,4 +28,5 @@
 from pytorch_pfn_extras.runtime._to import to  # NOQA
 
 if requires("2.0.0"):
+    from pytorch_pfn_extras import ops  # NOQA
     from pytorch_pfn_extras._dynamo import compile  # NOQA
diff --git a/pytorch_pfn_extras/ops/__init__.py b/pytorch_pfn_extras/ops/__init__.py
new file mode 100644
index 000000000..56af0178f
--- /dev/null
+++ b/pytorch_pfn_extras/ops/__init__.py
@@ -0,0 +1 @@
+from pytorch_pfn_extras.ops.register import OpDesc, register  # NOQA
diff --git a/pytorch_pfn_extras/ops/register.py b/pytorch_pfn_extras/ops/register.py
new file mode 100644
index 000000000..45a932b58
--- /dev/null
+++ b/pytorch_pfn_extras/ops/register.py
@@ -0,0 +1,91 @@
+from typing import Any, Callable, cast
+
+import torch
+import torch.library
+
+# Libraries used to store the op definitions
+library = torch.library.Library("ppe", "DEF")
+library_impl = torch.library.Library("ppe", "IMPL", "CompositeExplicitAutograd")
+library_autograd_impl = torch.library.Library("ppe", "IMPL", "Autograd")
+library_meta_impl = torch.library.Library("ppe", "IMPL", "Meta")
+
+
+class OpDesc:
+    """Metadata to register an op to torch.library.
+
+    Attributes:
+        op (callable): code to be executed in the forward/backward of the op.
+        meta (callable): function to perform shape inference for the
+            forward/backward passes.
+        signature (str): arguments and return type of the op, e.g.
+            ``"(Tensor a, Tensor b) -> Tensor[]"``.
+    """
+
+    def __init__(
+        self,
+        op: Callable[..., Any],
+        meta: Callable[..., Any],
+        signature: str,
+    ) -> None:
+        self.op = op
+        self.meta = meta
+        self.signature = signature
+
+
+def _get_autograd(name: str) -> Callable[..., Any]:
+    class RunBackward(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, *args, **kwargs):  # type: ignore[no-untyped-def]
+            ctx.save_for_backward(*args)
+            op_h = torch._C._dispatch_find_schema_or_throw(
+                f"ppe::{name}_fwd", ""
+            )
+            return torch._C._dispatch_call_boxed(op_h, *args, **kwargs)
+
+        @staticmethod
+        def backward(ctx, *args):  # type: ignore[no-untyped-def]
+            i_args = tuple(ctx.saved_tensors)
+            op_h = torch._C._dispatch_find_schema_or_throw(
+                f"ppe::{name}_bwd", ""
+            )
+            return torch._C._dispatch_call_boxed(op_h, *(args + i_args), **{})
+
+    return cast(Callable[..., Any], RunBackward.apply)
+
+
+def register(
+    name: str,
+    fwd_op: OpDesc,
+    bwd_op: OpDesc,
+) -> None:
+    """
+    Register a custom op under ``torch.ops.ppe.name``.
+
+    The op appears as a primitive op in the forward and backward
+    ``torch.fx.Graph``s after compiling torch code with the ``aot_autograd``
+    backend. Note that the backward op receives the gradients of the forward
+    outputs followed by all the forward arguments. This means that if the
+    forward signature is ``fwd_op(x, y)``, the custom ``bwd_op`` needs a
+    signature like ``bwd_op(grad_output, x, y)``.
+
+    Arguments:
+        name (str): name under which the op is registered in ``torch.ops.ppe``.
+        fwd_op (ppe.ops.OpDesc): op that is executed in the forward pass.
+        bwd_op (ppe.ops.OpDesc): op that is executed in the backward pass.
+    """
+    function_sig = f"{name}{fwd_op.signature}"
+    function_fwd_sig = f"{name}_fwd{fwd_op.signature}"
+    function_bwd_sig = f"{name}_bwd{bwd_op.signature}"
+    for s in (function_sig, function_fwd_sig, function_bwd_sig):
+        library.define(s)
+
+    def function(*args):  # type: ignore[no-untyped-def]
+        op_h = torch._C._dispatch_find_schema_or_throw(f"ppe::{name}_fwd", "")
+        return torch._C._dispatch_call_boxed(op_h, *args, **{})
+
+    library_impl.impl(name, function)
+    library_impl.impl(f"{name}_fwd", fwd_op.op)
+    library_impl.impl(f"{name}_bwd", bwd_op.op)
+    library_meta_impl.impl(f"{name}_fwd", fwd_op.meta)
+    library_meta_impl.impl(f"{name}_bwd", bwd_op.meta)
+    library_autograd_impl.impl(name, _get_autograd(name))
diff --git a/stubs/torch/_C/__init__.pyi b/stubs/torch/_C/__init__.pyi
index fb7e05143..c6caff9d1 100644
--- a/stubs/torch/_C/__init__.pyi
+++ b/stubs/torch/_C/__init__.pyi
@@ -4073,6 +4073,8 @@ def _activate_cuda_trace() -> None: ...
 # Defined in torch/csrc/Module.cpp
 def _current_graph_task_id() -> _int: ...
 def _current_autograd_node() -> _Node: ...
+def _dispatch_find_schema_or_throw(name: str, postfix: str) -> Any: ...
+def _dispatch_call_boxed(op: Any, args: Any, kwargs: Any) -> Any: ...
 
 class _OutOfMemoryError:
     pass
diff --git a/stubs/torch/library/__init__.pyi b/stubs/torch/library/__init__.pyi
new file mode 100644
index 000000000..236aab9a0
--- /dev/null
+++ b/stubs/torch/library/__init__.pyi
@@ -0,0 +1,9 @@
+# flake8: noqa
+from typing import Any, Callable
+
+class Library:
+    def __init__(self, ns: str, kind: str, dispatch_key: str = "") -> None: ...
+    def impl(
+        self, name: str, fn: Callable[..., Any], dispatch_key: str = ""
+    ) -> None: ...
+    def define(self, name: str) -> None: ...
diff --git a/tests/pytorch_pfn_extras_tests/test_ops/test_register.py b/tests/pytorch_pfn_extras_tests/test_ops/test_register.py
new file mode 100644
index 000000000..cf32eebfd
--- /dev/null
+++ b/tests/pytorch_pfn_extras_tests/test_ops/test_register.py
@@ -0,0 +1,79 @@
+import sys
+
+import pytest
+import pytorch_pfn_extras as ppe
+import torch
+
+
+def _get_function_nodes(fx_module):
+    return [
+        node for node in fx_module.graph.nodes if node.op == "call_function"
+    ]
+
+
+@pytest.mark.skipif(
+    not ppe.requires("2.1.0") or sys.platform == "win32",
+    reason="torch custom ops only work for PyTorch>=2.1 and Linux",
+)
+def test_register():
+    def test(a):
+        return a * 2
+
+    def test_bwd(g, a):
+        return g
+
+    def test_meta(a):
+        return torch.empty_like(a)
+
+    def test_bwd_meta(g, a):
+        return torch.empty_like(a)
+
+    fwd_op = ppe.ops.OpDesc(test, test_meta, "(Tensor a) -> Tensor")
+    bwd_op = ppe.ops.OpDesc(
+        test_bwd, test_bwd_meta, "(Tensor g, Tensor a) -> Tensor"
+    )
+    ppe.ops.register("test", fwd_op, bwd_op)
+
+    class TestModule(torch.nn.Module):
+        def forward(self, a):
+            # Call the custom op registered under the "ppe" namespace
+            return torch.ops.ppe.test(a)
+
+    found_fwd_op = False
+    found_bwd_op = False
+
+    from functorch.compile import make_boxed_func
+    from torch._dynamo.backends.common import aot_autograd
+
+    # Detect the custom ops in the compiled forward/backward graphs
+    def fwd_compiler_fn(fx_module: torch.fx.GraphModule, _):
+        nonlocal found_fwd_op
+        function_nodes = _get_function_nodes(fx_module)
+        assert len(function_nodes) == 1
+        found_fwd_op = (
+            function_nodes[0].target is torch.ops.ppe.test_fwd.default
+        )
+        return make_boxed_func(fx_module)
+
+    def bwd_compiler_fn(fx_module: torch.fx.GraphModule, _):
+        nonlocal found_bwd_op
+        function_nodes = _get_function_nodes(fx_module)
+        assert len(function_nodes) == 1
+        found_bwd_op = (
+            function_nodes[0].target is torch.ops.ppe.test_bwd.default
+        )
+        return make_boxed_func(fx_module)
+
+    aot_backend = aot_autograd(  # type: ignore[no-untyped-call]
+        fw_compiler=fwd_compiler_fn,
+        bw_compiler=bwd_compiler_fn,
+    )
+    m = TestModule()
+    torch._dynamo.reset()
+    module_opt = torch.compile(m, fullgraph=True, backend=aot_backend)
+    shape = [1, 16, 2048, 128]
+    x = torch.ones(shape, requires_grad=True)
+    y = module_opt(x)
+    y.sum().backward()
+    assert found_fwd_op
+    assert found_bwd_op
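
Usage sketch (editor's note, not part of the patch): the snippet below distills the test above into a minimal example of the new ``ppe.ops`` API, assuming PyTorch >= 2.1 on Linux. The op name ``mul2``, the ``Model`` module, and the choice of the ``"aot_eager"`` backend are illustrative; unlike the registration test, the backward here returns the mathematically correct gradient (``2 * g``).

    import torch
    import pytorch_pfn_extras as ppe

    # Forward/backward implementations plus their meta (shape inference) counterparts.
    def mul2(a):
        return a * 2

    def mul2_bwd(g, a):  # receives grad_output followed by the forward arguments
        return g * 2

    def mul2_meta(a):
        return torch.empty_like(a)

    def mul2_bwd_meta(g, a):
        return torch.empty_like(a)

    ppe.ops.register(
        "mul2",
        ppe.ops.OpDesc(mul2, mul2_meta, "(Tensor a) -> Tensor"),
        ppe.ops.OpDesc(mul2_bwd, mul2_bwd_meta, "(Tensor g, Tensor a) -> Tensor"),
    )

    class Model(torch.nn.Module):
        def forward(self, a):
            # Appears as ppe::mul2_fwd / ppe::mul2_bwd in the compiled graphs.
            return torch.ops.ppe.mul2(a)

    # "aot_eager" runs aot_autograd without codegen, mirroring the test's setup.
    model = torch.compile(Model(), fullgraph=True, backend="aot_eager")
    x = torch.ones(4, requires_grad=True)
    y = model(x)
    y.sum().backward()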