Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support 1D, 2D, and 3D avg and max pooling dynamo converters #2317

Merged
merged 4 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,3 +1378,90 @@ def aten_ops_linear(
weight=args[1],
bias=args_bounds_check(args, 2, None),
)


def avg_pool_param_validator(pool_node: Node) -> bool:
    """Report whether an average-pooling node uses only supported arguments.

    Positional argument 4 of the node is read as ``ceil_mode`` (falling back
    to ``False``) and positional argument 6 as ``divisor_override``.

    Args:
        pool_node: The FX node whose ``args`` are inspected.

    Returns:
        ``False`` (after emitting a debug log) when ``ceil_mode`` is anything
        other than the literal ``False`` or when ``divisor_override`` is not
        ``None``; ``True`` otherwise.
    """
    node_args = pool_node.args
    ceil_mode = args_bounds_check(node_args, 4, False)
    # No explicit fallback is given here; presumably args_bounds_check yields
    # None when index 6 is absent -- TODO confirm against its definition.
    divisor_override = args_bounds_check(node_args, 6)

    # ceil_mode takes precedence: only the first unsupported argument is logged.
    rejection = None
    if ceil_mode is not False:
        rejection = f"Currently we don't support specifying ceil_mode, got ceil_mode={ceil_mode}."
    elif divisor_override is not None:
        rejection = f"Currently we don't support divisor_override, got divisor_override={divisor_override}."

    if rejection is None:
        return True

    _LOGGER.debug(rejection)
    return False


# Note: AvgPool1d uses avg_pool2d as it converts to 2D first.
@dynamo_tensorrt_converter(torch.ops.aten.avg_pool1d.default, capability_validator=avg_pool_param_validator)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.avg_pool2d.default, capability_validator=avg_pool_param_validator)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.avg_pool3d.default, capability_validator=avg_pool_param_validator)  # type: ignore[misc]
def aten_ops_avg_pool(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert an aten ``avg_pool1d``/``avg_pool2d``/``avg_pool3d`` call.

    The positional ``args`` are unpacked and forwarded to
    ``impl.pool.avg_poolNd`` with ``SourceIR.ATEN`` as the source IR.
    ``kwargs`` is accepted but not read.

    Args:
        network: Network the pooling layer is added to.
        target: The op being converted.
        args: Positional op arguments; indices 0 and 1 (input and kernel
            size) are required, indices 2-6 are optional.
        kwargs: Keyword op arguments (unused here).
        name: Name forwarded to the implementation.

    Returns:
        Whatever ``impl.pool.avg_poolNd`` returns.
    """
    pool_input, window = args[0], args[1]

    # Trailing arguments may be missing from the call; each one falls back to
    # the default listed here.
    optional_args = {
        "stride": args_bounds_check(args, 2, []),
        "padding": args_bounds_check(args, 3, 0),
        "ceil_mode": args_bounds_check(args, 4, False),
        "count_include_pad": args_bounds_check(args, 5, True),
        "divisor_override": args_bounds_check(args, 6, None),
    }

    return impl.pool.avg_poolNd(
        network,
        target,
        source_ir=SourceIR.ATEN,
        name=name,
        input=pool_input,
        kernel_size=window,
        **optional_args,
    )


def max_pool_param_validator(pool_node: Node) -> bool:
    """Report whether a max-pooling node uses only supported arguments.

    Positional argument 4 of the node is read as ``dilation`` (falling back
    to ``1``) and positional argument 5 as ``ceil_mode`` (falling back to
    ``False``).

    Args:
        pool_node: The FX node whose ``args`` are inspected.

    Returns:
        ``False`` (after emitting a debug log) when ``dilation`` does not
        compare equal to ``1`` or when ``ceil_mode`` is anything other than
        the literal ``False``; ``True`` otherwise.
    """
    node_args = pool_node.args
    dilation = args_bounds_check(node_args, 4, 1)
    ceil_mode = args_bounds_check(node_args, 5, False)

    # dilation takes precedence: only the first unsupported argument is logged.
    # NOTE(review): a sequence such as [1, 1] does not compare equal to 1, so
    # an explicitly spelled-out unit dilation is rejected here as well.
    rejection = None
    if dilation != 1:
        rejection = f"Currently we don't support dilation, got dilation={dilation}."
    elif ceil_mode is not False:
        rejection = f"Currently we don't support specifying ceil_mode, got ceil_mode={ceil_mode}."

    if rejection is None:
        return True

    _LOGGER.debug(rejection)
    return False


# Note: MaxPool1d uses max_pool2d as it converts to 2D first.
@dynamo_tensorrt_converter(torch.ops.aten.max_pool1d.default, capability_validator=max_pool_param_validator) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.max_pool2d.default, capability_validator=max_pool_param_validator) # type: ignore[misc]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could max_pool1d support be added here as well? Schema

Copy link
Collaborator Author

@zewenli98 zewenli98 Sep 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can add torch.ops.aten.max_pool1d.default but it won't be used. Even for torch.nn.AvgPool1d, it still calls torch.ops.aten.avg_pool2d.default, as you can see in the test file: https://github.com/pytorch/TensorRT/pull/2317/files#diff-9fce39bc42c66d2866c41665779cab7da0a4d3fe54576925e2b66c17a1cf1ebfR20-R43
But anyway, the 1d schema looks the same as the others, so I added it here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for that - I plan to add a lowering pass which will lead us to that converter, so it will still be helpful.

@dynamo_tensorrt_converter(torch.ops.aten.max_pool3d.default, capability_validator=max_pool_param_validator)  # type: ignore[misc]
def aten_ops_max_pool(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert an aten max-pooling call.

    The positional ``args`` are unpacked and forwarded to
    ``impl.pool.max_poolNd`` with ``SourceIR.ATEN`` as the source IR.
    ``kwargs`` is accepted but not read.

    Args:
        network: Network the pooling layer is added to.
        target: The op being converted.
        args: Positional op arguments; indices 0 and 1 (input and kernel
            size) are required, indices 2-5 are optional.
        kwargs: Keyword op arguments (unused here).
        name: Name forwarded to the implementation.

    Returns:
        Whatever ``impl.pool.max_poolNd`` returns.
    """
    pool_input, window = args[0], args[1]

    # Trailing arguments may be missing from the call; each one falls back to
    # the default listed here.
    optional_args = {
        "stride": args_bounds_check(args, 2, []),
        "padding": args_bounds_check(args, 3, 0),
        "dilation": args_bounds_check(args, 4, 1),
        "ceil_mode": args_bounds_check(args, 5, False),
    }

    return impl.pool.max_poolNd(
        network,
        target,
        source_ir=SourceIR.ATEN,
        name=name,
        input=pool_input,
        kernel_size=window,
        **optional_args,
    )
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
matmul,
normalization,
permutation,
pool,
reduce,
select,
shape,
Expand Down
105 changes: 105 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from typing import Optional, Sequence, Union

import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
from torch_tensorrt.fx.converters.converter_utils import (
has_dynamic_shape,
set_layer_name,
)
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor


def avg_poolNd(
    network: TRTNetwork,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    kernel_size: Sequence[int],
    stride: Union[int, Sequence[int]],
    padding: Union[int, Sequence[int]] = 0,
    ceil_mode: bool = False,
    count_include_pad: bool = True,
    divisor_override: Optional[int] = None,
) -> TRTTensor:
    """Add an average-pooling layer to ``network`` and return its output.

    The number of pooled dimensions is taken from ``len(kernel_size)``.

    Args:
        network: Network the pooling layer is added to.
        target: Op identifier passed to ``set_layer_name``.
        source_ir: Source IR passed to ``set_layer_name``.
        name: Layer name passed to ``set_layer_name``.
        input: Tensor to pool.
        kernel_size: Pooling window, one entry per pooled dimension.
        stride: Window stride; an empty list means "same as ``kernel_size``".
        padding: Padding applied to the layer.
        ceil_mode: Must be ``False``.
        count_include_pad: When ``False`` the layer is told to exclude
            padding from the average count.
        divisor_override: Must be ``None``.

    Returns:
        Output 0 of the newly added pooling layer.

    Raises:
        AssertionError: If the input shape is dynamic and dimension 1 is -1.
        RuntimeError: If ``ceil_mode`` is not ``False`` or
            ``divisor_override`` is not ``None``.
    """
    if has_dynamic_shape(input.shape):
        assert input.shape[1] != -1, "Channel dim can't be dynamic for pooling."

    # Reject the options this converter does not implement.
    if ceil_mode is not False:
        raise RuntimeError("ceil_mode is not yet supported!")
    if divisor_override is not None:
        raise RuntimeError("divisor_override is not yet supported!")

    # Every per-dimension attribute is normalised to this many entries.
    rank = len(kernel_size)
    window = extend_attr_to_tuple(kernel_size, rank)
    # An empty stride list falls back to the window size.
    strides = window if stride == [] else extend_attr_to_tuple(stride, rank)
    pads = extend_attr_to_tuple(padding, rank)

    layer = network.add_pooling_nd(
        input=input,
        type=trt.PoolingType.AVERAGE,
        window_size=window,
    )
    layer.stride_nd = strides
    layer.padding_nd = pads
    # The layer flag is the inverse of the PyTorch-style argument.
    layer.average_count_excludes_padding = not count_include_pad

    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)


def max_poolNd(
    network: TRTNetwork,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    kernel_size: Sequence[int],
    stride: Union[int, Sequence[int]],
    padding: Union[int, Sequence[int]] = 0,
    dilation: Union[int, Sequence[int]] = 1,
    ceil_mode: bool = False,
) -> TRTTensor:
    """Add a max-pooling layer to ``network`` and return its output.

    The number of pooled dimensions is taken from ``len(kernel_size)``.

    Args:
        network: Network the pooling layer is added to.
        target: Op identifier passed to ``set_layer_name``.
        source_ir: Source IR passed to ``set_layer_name``.
        name: Layer name passed to ``set_layer_name``.
        input: Tensor to pool.
        kernel_size: Pooling window, one entry per pooled dimension.
        stride: Window stride; an empty list means "same as ``kernel_size``".
        padding: Padding applied to the layer.
        dilation: Must describe no dilation: either a value equal to ``1``
            or a non-empty list/tuple whose entries all equal ``1``.
        ceil_mode: Must be ``False``.

    Returns:
        Output 0 of the newly added pooling layer.

    Raises:
        AssertionError: If the input shape is dynamic and dimension 1 is -1.
        RuntimeError: If ``dilation`` requests actual dilation or
            ``ceil_mode`` is not ``False``.
    """
    if has_dynamic_shape(input.shape):
        assert input.shape[1] != -1, "Channel dim can't be dynamic for pooling."

    if dilation != 1:
        # The annotation allows a per-dimension sequence, and a list never
        # compares equal to the int 1. A sequence made only of ones (e.g.
        # [1, 1]) still means "no dilation", so accept it instead of raising.
        is_unit_sequence = (
            isinstance(dilation, (list, tuple))
            and len(dilation) > 0
            and all(d == 1 for d in dilation)
        )
        if not is_unit_sequence:
            raise RuntimeError("dilation is not yet supported!")

    if ceil_mode is not False:
        raise RuntimeError("ceil_mode is not yet supported!")

    # Every per-dimension attribute is normalised to this many entries.
    dim = len(kernel_size)

    kernel_size = extend_attr_to_tuple(kernel_size, dim)

    # An empty stride list falls back to the window size.
    if stride == []:
        stride = kernel_size
    else:
        stride = extend_attr_to_tuple(stride, dim)

    padding = extend_attr_to_tuple(padding, dim)

    # add max pooling layer
    pool_layer = network.add_pooling_nd(
        input=input,
        type=trt.PoolingType.MAX,
        window_size=kernel_size,
    )

    pool_layer.stride_nd = stride
    pool_layer.padding_nd = padding

    set_layer_name(pool_layer, target, name, source_ir)
    return pool_layer.get_output(0)
Loading
Loading