Merge pull request pfnet#796 from emcastillo/custom_torch_ops
Add interface for registering custom ops in `torch.ops.ppe`
emcastillo authored Dec 7, 2023
2 parents 73082f6 + f1fd387 commit 8e6438c
Showing 6 changed files with 183 additions and 0 deletions.
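In short, ``ppe.ops.register(name, fwd_op, bwd_op)`` defines three ops in the ``ppe`` dispatcher namespace: the user-facing op plus explicit forward and backward kernels, each with a shape-inference (Meta) implementation. As a rough sketch of the resulting surface (the name ``myop`` is purely illustrative, not part of this commit):

# After ppe.ops.register("myop", fwd_desc, bwd_desc):
torch.ops.ppe.myop      # user-facing op; its autograd formula calls the two ops below
torch.ops.ppe.myop_fwd  # forward kernel (fwd_desc.op), with fwd_desc.meta as its Meta impl
torch.ops.ppe.myop_bwd  # backward kernel (bwd_desc.op), with bwd_desc.meta as its Meta impl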
1 change: 1 addition & 0 deletions pytorch_pfn_extras/__init__.py
@@ -28,4 +28,5 @@
from pytorch_pfn_extras.runtime._to import to # NOQA

if requires("2.0.0"):
from pytorch_pfn_extras import ops # NOQA
from pytorch_pfn_extras._dynamo import compile # NOQA
1 change: 1 addition & 0 deletions pytorch_pfn_extras/ops/__init__.py
@@ -0,0 +1 @@
from pytorch_pfn_extras.ops.register import OpDesc, register # NOQA
91 changes: 91 additions & 0 deletions pytorch_pfn_extras/ops/register.py
@@ -0,0 +1,91 @@
from typing import Any, Callable, cast

import torch
import torch.library

# Libraries used to store the op definitions
library = torch.library.Library("ppe", "DEF")
library_impl = torch.library.Library("ppe", "IMPL", "CompositeExplicitAutograd")
library_autograd_impl = torch.library.Library("ppe", "IMPL", "Autograd")
library_meta_impl = torch.library.Library("ppe", "IMPL", "Meta")


class OpDesc:
"""Metadata to register an op to torch.library.
Attributes:
op (callable): code to be executed in the forward/backward of the op.
meta (callable): function to perform shape inference for forward/backward
passes.
signature (str): Arguments and return type of the function, e.g.
``"(Tensor a, Tensor b) -> Tensor[]"``.
"""

def __init__(
self,
op: Callable[..., Any],
meta: Callable[..., Any],
signature: str,
) -> None:
self.op = op
self.meta = meta
self.signature = signature


def _get_autograd(name: str) -> Callable[..., Any]:
class RunBackward(torch.autograd.Function):
@staticmethod
def forward(ctx, *args, **kwargs): # type: ignore[no-untyped-def]
ctx.save_for_backward(*args)
op_h = torch._C._dispatch_find_schema_or_throw(
f"ppe::{name}_fwd", ""
)
return torch._C._dispatch_call_boxed(op_h, *args, **kwargs)

@staticmethod
def backward(ctx, *args): # type: ignore[no-untyped-def]
i_args = tuple(ctx.saved_tensors)
op_h = torch._C._dispatch_find_schema_or_throw(
f"ppe::{name}_bwd", ""
)
return torch._C._dispatch_call_boxed(op_h, *(args + i_args), **{})

return cast(Callable[..., Any], RunBackward.apply)


def register(
name: str,
fwd_op: OpDesc,
bwd_op: OpDesc,
) -> None:
"""
Register a custom op under ``torch.ops.ppe.name``.
The function appears as a primitive op in the forward and backward
``torch.fx.Graph``s after compiling torch code with the ``aot_autograd`` backend.
Note that the backward op receives the incoming gradients followed by the
original forward arguments. This means that if the forward op is called as
``fwd_op(x, y)``, the custom ``bwd_op`` needs a signature like
``bwd_op(grad_output, x, y)``.
Arguments:
name (str): name under which the op is registered in ``torch.ops.ppe``.
fwd_op (ppe.ops.OpDesc): code that is executed in the forward pass.
bwd_op (ppe.ops.OpDesc): code that is executed in the backward pass.
"""
function_sig = f"{name}{fwd_op.signature}"
function_fwd_sig = f"{name}_fwd{fwd_op.signature}"
function_bwd_sig = f"{name}_bwd{bwd_op.signature}"
for s in (function_sig, function_fwd_sig, function_bwd_sig):
library.define(s)

def function(*args): # type: ignore[no-untyped-def]
op_h = torch._C._dispatch_find_schema_or_throw(f"ppe::{name}_fwd", "")
return torch._C._dispatch_call_boxed(op_h, *args, **{})

library_impl.impl(name, function)
library_impl.impl(f"{name}_fwd", fwd_op.op)
library_impl.impl(f"{name}_bwd", bwd_op.op)
library_meta_impl.impl(f"{name}_fwd", fwd_op.meta)
library_meta_impl.impl(f"{name}_bwd", bwd_op.meta)
library_autograd_impl.impl(name, _get_autograd(name))
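
For reference, a minimal usage sketch of the interface added here, pieced together from the code above and the test below; the op name ``mymul`` and its doubling body are illustrative only, and PyTorch >= 2.1 on Linux is assumed (the same condition the test skips on):

import torch
import pytorch_pfn_extras as ppe

def mymul_fwd(a):
    # Forward computation of the custom op.
    return a * 2

def mymul_bwd(g, a):
    # The backward op receives the incoming gradient followed by the forward inputs.
    return g * 2

def mymul_fwd_meta(a):
    # Shape inference for the forward pass.
    return torch.empty_like(a)

def mymul_bwd_meta(g, a):
    # Shape inference for the backward pass.
    return torch.empty_like(a)

fwd_desc = ppe.ops.OpDesc(mymul_fwd, mymul_fwd_meta, "(Tensor a) -> Tensor")
bwd_desc = ppe.ops.OpDesc(
    mymul_bwd, mymul_bwd_meta, "(Tensor g, Tensor a) -> Tensor"
)
ppe.ops.register("mymul", fwd_desc, bwd_desc)

# The op should now be usable as a dispatcher primitive, in eager mode or under
# torch.compile with an aot_autograd backend (as in the test below).
x = torch.ones(4, requires_grad=True)
y = torch.ops.ppe.mymul(x)
y.sum().backward()  # dispatches to mymul_bwd(grad_output, x)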
2 changes: 2 additions & 0 deletions stubs/torch/_C/__init__.pyi
@@ -4073,6 +4073,8 @@ def _activate_cuda_trace() -> None: ...
# Defined in torch/csrc/Module.cpp
def _current_graph_task_id() -> _int: ...
def _current_autograd_node() -> _Node: ...
def _dispatch_find_schema_or_throw(name: str, postfix: str) -> Any: ...
def _dispatch_call_boxed(op: Any, *args: Any, **kwargs: Any) -> Any: ...

class _OutOfMemoryError:
pass
9 changes: 9 additions & 0 deletions stubs/torch/library/__init__.pyi
@@ -0,0 +1,9 @@
# flake8: noqa
from typing import Any, Callable

class Library:
def __init__(self, ns: str, kind: str, dispatch_key: str = "") -> None: ...
def impl(
self, name: str, fn: Callable[..., Any], dispatch_key: str = ""
) -> None: ...
def define(self, name: str) -> None: ...
79 changes: 79 additions & 0 deletions tests/pytorch_pfn_extras_tests/test_ops/test_register.py
@@ -0,0 +1,79 @@
import sys

import pytest
import pytorch_pfn_extras as ppe
import torch


def _get_function_nodes(fx_module):
return [
node for node in fx_module.graph.nodes if node.op == "call_function"
]


@pytest.mark.skipif(
not ppe.requires("2.1.0") or sys.platform == "win32",
reason="torch custom ops only works for PyTorch>=2.1 and linux",
)
def test_register():
def test(a):
return a * 2

def test_bwd(g, a):
return g

def test_meta(a):
return torch.empty_like(a)

def test_bwd_meta(g, a):
return torch.empty_like(a)

fwd_op = ppe.ops.OpDesc(test, test_meta, "(Tensor a) -> Tensor")
bwd_op = ppe.ops.OpDesc(
test_bwd, test_bwd_meta, "(Tensor g, Tensor a) -> Tensor"
)
ppe.ops.register("test", fwd_op, bwd_op)

class TestModule(torch.nn.Module):
def forward(self, a):
# Call the custom function
return torch.ops.ppe.test(a)

found_fwd_op = False
found_bwd_op = False

from functorch.compile import make_boxed_func
from torch._dynamo.backends.common import aot_autograd

# Detect the custom ops
def fwd_compiler_fn(fx_module: torch.fx.GraphModule, _):
nonlocal found_fwd_op
function_nodes = _get_function_nodes(fx_module)
assert len(function_nodes) == 1
found_fwd_op = (
function_nodes[0].target is torch.ops.ppe.test_fwd.default
)
return make_boxed_func(fx_module)

def bwd_compiler_fn(fx_module: torch.fx.GraphModule, _):
nonlocal found_bwd_op
function_nodes = _get_function_nodes(fx_module)
assert len(function_nodes) == 1
found_bwd_op = (
function_nodes[0].target is torch.ops.ppe.test_bwd.default
)
return make_boxed_func(fx_module)

aot_backend = aot_autograd( # type: ignore[no-untyped-call]
fw_compiler=fwd_compiler_fn,
bw_compiler=bwd_compiler_fn,
)
m = TestModule()
torch._dynamo.reset()
module_opt = torch.compile(m, fullgraph=True, backend=aot_backend)
shape = [1, 16, 2048, 128]
x = torch.ones(shape, requires_grad=True)
y = module_opt(x)
y.sum().backward()
assert found_fwd_op
assert found_bwd_op
