Update usage of PyTorch's custom op API
Hi, I maintain the custom ops story in PyTorch. This PR updates the usage of
PyTorch's private custom op API to a newer API. The new API is still private,
but it is closer to what we want it to be.

Test Plan:
- wait for CI
zou3519 committed Sep 22, 2023
1 parent e6e8099 commit 6ba9466
Showing 2 changed files with 12 additions and 22 deletions.
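For orientation, here is a minimal sketch of the pattern both files below are migrated to, based on the torch._custom_ops module as it exists around PyTorch 2.1. The op name mylib::sample_op, its schema, and the implementations are illustrative only and are not part of this commit:

import torch
import torch._custom_ops as library

# Declare the schema; after this the op is callable as torch.ops.mylib.sample_op.
# ("mylib::sample_op" is a hypothetical example, not an op added by this commit.)
library.custom_op("mylib::sample_op", "(Tensor x) -> Tensor")

# Eager implementation (registered for CPU and CUDA by default).
@library.impl("mylib::sample_op")
def sample_op_impl(x):
    return torch.relu(x)

# Abstract ("fake") implementation used for tracing and shape propagation.
@library.impl_abstract("mylib::sample_op")
def sample_op_abstract(x):
    return torch.empty_like(x)

Since the module is private, the exact call surface may shift between releases; the diffs below show the actual registrations for tensorrt::einsum and tensorrt::maxpool1d.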
17 changes: 6 additions & 11 deletions py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py
@@ -1,26 +1,21 @@
 from typing import Any, Dict, Optional, Sequence, Tuple
 
 import torch
-from torch._custom_op.impl import custom_op
+import torch._custom_ops as library
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import register_substitution
 from torch_tensorrt.fx.converter_registry import tensorrt_converter
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
 
-@custom_op(
-    qualname="tensorrt::einsum",
-    manual_schema="(str equation, Tensor[] tensors) -> Tensor",
+library.custom_op(
+    "tensorrt::einsum",
+    "(str equation, Tensor[] tensors) -> Tensor",
 )
-def einsum(equation, tensors):  # type: ignore[no-untyped-def]
-    # Defines operator schema, name, namespace, and function header
-    ...
-
 
-@einsum.impl("cpu")  # type: ignore[misc]
-@einsum.impl("cuda")  # type: ignore[misc]
-@einsum.impl_abstract()  # type: ignore[misc]
+@library.impl("tensorrt::einsum")  # type: ignore[misc]
+@library.impl_abstract("tensorrt::einsum")  # type: ignore[misc]
 def einsum_generic(
     *args: Any,
     **kwargs: Any,
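Once this registration runs (i.e. the module above has been imported), the op is exposed under the namespace given in its qualname. A call would look roughly like the following sketch, which simply follows the "(str equation, Tensor[] tensors) -> Tensor" schema:

import torch
import torch_tensorrt  # assumed: importing the package triggers the registration above

a = torch.randn(2, 3)
b = torch.randn(3, 4)
out = torch.ops.tensorrt.einsum("ij,jk->ik", [a, b])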
17 changes: 6 additions & 11 deletions py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py
@@ -1,7 +1,7 @@
 from typing import Any, Dict, Optional, Tuple
 
 import torch
-from torch._custom_op.impl import custom_op
+import torch._custom_ops as library
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import register_substitution
 from torch_tensorrt.fx.converter_registry import tensorrt_converter
@@ -20,14 +20,10 @@
 # types. The namespace, such as tensorrt, will cause the op to be registered as torch.ops.tensorrt.your_op
 # Then, create a placeholder function with no operations, but having the same schema and naming as that
 # used in the decorator
-@custom_op(
-    qualname="tensorrt::maxpool1d",
-    manual_schema="(Tensor x, int[1] kernel_size, int[1] stride, int[1] padding, int[1] dilation, bool ceil_mode) -> Tensor",
+library.custom_op(
+    "tensorrt::maxpool1d",
+    "(Tensor x, int[1] kernel_size, int[1] stride, int[1] padding, int[1] dilation, bool ceil_mode) -> Tensor"
 )
-def maxpool1d(x, kernel_size, stride, padding, dilation, ceil_mode):  # type: ignore[no-untyped-def]
-    # Defines operator schema, name, namespace, and function header
-    ...
-
 
 # 2. The Generic Implementation
 #
@@ -36,9 +32,8 @@ def maxpool1d(x, kernel_size, stride, padding, dilation, ceil_mode): # type: ig
 # is desirable. If the operator to replace is a custom module you've written, then add its Torch
 # implementation here. Note that the function header to the generic function can have specific arguments
 # as in the above placeholder
-@maxpool1d.impl("cpu")  # type: ignore[misc]
-@maxpool1d.impl("cuda")  # type: ignore[misc]
-@maxpool1d.impl_abstract()  # type: ignore[misc]
+@library.impl("tensorrt::maxpool1d")  # type: ignore[misc]
+@library.impl_abstract("tensorrt::maxpool1d")  # type: ignore[misc]
 def maxpool1d_generic(
     *args: Any,
     **kwargs: Any,
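As above, a hypothetical call against the maxpool1d schema, with single-element list arguments matching the int[1] parameters:

import torch
import torch_tensorrt  # assumed: importing the package triggers the registration above

x = torch.randn(1, 3, 32)
out = torch.ops.tensorrt.maxpool1d(x, [3], [1], [0], [1], False)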
