[Refactor] Refactor rewriter context for MMRazor #1483

Merged · 10 commits · Dec 13, 2022
Changes from 7 commits
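
This PR changes the rewriter calling convention: rewritten functions no longer receive the rewriter context as an implicit first argument `ctx`; instead they call `FUNCTION_REWRITER.get_context()` where the context is needed. A minimal before/after sketch of the pattern (the target `some.module.func` is hypothetical):

```python
from mmdeploy.core import FUNCTION_REWRITER

# Old convention: the context is injected as the first positional argument.
@FUNCTION_REWRITER.register_rewriter(func_name='some.module.func')
def func__rewrite_old(ctx, self, x):
    return ctx.origin_func(self, x)

# New convention: the signature matches the original function and the
# context is fetched explicitly only where it is needed.
@FUNCTION_REWRITER.register_rewriter(func_name='some.module.func')
def func__rewrite_new(self, x):
    ctx = FUNCTION_REWRITER.get_context()
    return ctx.origin_func(self, x)
```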
2 changes: 1 addition & 1 deletion demo/demo_rewrite.py
@@ -13,7 +13,7 @@

@FUNCTION_REWRITER.register_rewriter(
func_name='torchvision.models.ResNet._forward_impl')
def forward_of_resnet(ctx, self, x):
def forward_of_resnet(self, x):
"""Rewrite the forward implementation of resnet.

Early return the feature map after two down-sampling steps.
7 changes: 3 additions & 4 deletions docs/en/07-developer-guide/partition_model.md
@@ -13,13 +13,13 @@ from mmdeploy.core import FUNCTION_REWRITER, mark

@mark(
'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks'])
def __forward_impl(ctx, self, img, img_metas=None, **kwargs):
def __forward_impl(self, img, img_metas=None, **kwargs):
...


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.detectors.base.BaseDetector.forward')
def base_detector__forward(ctx, self, img, img_metas=None, **kwargs):
def base_detector__forward(self, img, img_metas=None, **kwargs):
...
# call the mark function
return __forward_impl(...)
@@ -32,8 +32,7 @@ from mmdeploy.core import FUNCTION_REWRITER, mark

@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.YOLOV3Head.get_bboxes')
def yolov3_head__get_bboxes(ctx,
self,
def yolov3_head__get_bboxes(self,
pred_maps,
img_metas,
cfg=None,
5 changes: 3 additions & 2 deletions docs/en/07-developer-guide/support_new_model.md
@@ -11,7 +11,8 @@ from mmdeploy.core import FUNCTION_REWRITER

@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.repeat', backend='tensorrt')
def repeat_static(ctx, input, *size):
def repeat_static(input, *size):
ctx = FUNCTION_REWRITER.get_context()
origin_func = ctx.origin_func
if input.dim() == 1 and len(size) == 1:
return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
@@ -72,7 +73,7 @@ The mappings between PyTorch and ONNX are defined in PyTorch with symbolic functions

```python
@SYMBOLIC_REWRITER.register_symbolic('squeeze', is_pytorch=True)
def squeeze_default(ctx, g, self, dim=None):
def squeeze_default(g, self, dim=None):
if dim is None:
dims = []
for i, size in enumerate(self.type().sizes()):
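
As a usage note for the rewriters above (a hedged sketch, not part of this diff): a backend-specific rewriter such as `repeat_static` only takes effect while a rewriter context for that backend is active, which the export pipeline sets up internally. Assuming `RewriterContext` from `mmdeploy.core` accepts a deploy config and a backend name:

```python
import torch
from mmdeploy.core import RewriterContext

x = torch.rand(8)
y_plain = x.repeat(2)  # outside any context, the original repeat is used

# Inside a TensorRT rewriter context the registered rewrite is applied.
# An empty config is used here only for illustration; a real deploy config
# would normally be passed.
with RewriterContext(cfg=dict(), backend='tensorrt'), torch.no_grad():
    y_rewritten = x.repeat(2)

assert torch.equal(y_plain, y_rewritten)  # same values, different export path
```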
5 changes: 3 additions & 2 deletions docs/en/07-developer-guide/test_rewritten_models.md
@@ -18,7 +18,7 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
# Custom rewritten function
@FUNCTION_REWRITER.register_rewriter(
'mmcls.models.classifiers.BaseClassifier.forward', backend='default')
def forward_of_base_classifier(ctx, self, img, *args, **kwargs):
def forward_of_base_classifier(self, img, *args, **kwargs):
"""Rewrite `forward` for default backend."""
return self.simple_test(img, {})
```
@@ -63,7 +63,8 @@ In the first example, the output is generated in Python. Sometimes we may make b
# Custom rewritten function
@FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.segmentors.BaseSegmentor.forward')
def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs):
def base_segmentor__forward(self, img, img_metas=None, **kwargs):
ctx = FUNCTION_REWRITER.get_context()
if img_metas is None:
img_metas = {}
assert isinstance(img_metas, dict)
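
For context, a test for the classifier rewrite above typically instantiates a dummy subclass and compares outputs with and without the rewriter enabled. The sketch below mirrors that pattern and assumes the `forward_of_base_classifier` rewriter shown earlier has been registered (for example by importing the module that defines it); the `DummyClassifier` is hypothetical:

```python
import torch
from mmcls.models.classifiers import BaseClassifier
from mmdeploy.core import RewriterContext


class DummyClassifier(BaseClassifier):
    """Minimal classifier whose simple_test just echoes its input."""

    def extract_feat(self, imgs):
        pass

    def forward_train(self, imgs, **kwargs):
        pass

    def simple_test(self, img, img_metas=None, **kwargs):
        return img


model = DummyClassifier().eval()
inputs = torch.rand(1, 3, 8, 8)

# Inside the rewriter context, BaseClassifier.forward is replaced by the
# rewritten version, which simply calls simple_test.
with RewriterContext(cfg=dict()), torch.no_grad():
    backend_output = model(inputs)

assert torch.equal(backend_output, inputs)
```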
7 changes: 3 additions & 4 deletions docs/zh_cn/07-developer-guide/partition_model.md
@@ -13,13 +13,13 @@ from mmdeploy.core import FUNCTION_REWRITER, mark

@mark(
'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks'])
def __forward_impl(ctx, self, img, img_metas=None, **kwargs):
def __forward_impl(self, img, img_metas=None, **kwargs):
...


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.detectors.base.BaseDetector.forward')
def base_detector__forward(ctx, self, img, img_metas=None, **kwargs):
def base_detector__forward(self, img, img_metas=None, **kwargs):
...
# call the mark function
return __forward_impl(...)
@@ -32,8 +32,7 @@ from mmdeploy.core import FUNCTION_REWRITER, mark

@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.YOLOV3Head.get_bboxes')
def yolov3_head__get_bboxes(ctx,
self,
def yolov3_head__get_bboxes(self,
pred_maps,
img_metas,
cfg=None,
5 changes: 3 additions & 2 deletions docs/zh_cn/07-developer-guide/support_new_model.md
@@ -10,7 +10,8 @@ PyTorch 神经网络是用 python 编写的，可以简化算法的开发。但
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.repeat', backend='tensorrt')
def repeat_static(ctx, input, *size):
def repeat_static(input, *size):
ctx = FUNCTION_REWRITER.get_context()
origin_func = ctx.origin_func
if input.dim() == 1 and len(size) == 1:
return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
@@ -67,7 +68,7 @@ PyTorch 和 ONNX 之间的映射是通过 PyTorch 中的符号函数进行定义

```python
@SYMBOLIC_REWRITER.register_symbolic('squeeze', is_pytorch=True)
def squeeze_default(ctx, g, self, dim=None):
def squeeze_default(g, self, dim=None):
if dim is None:
dims = []
for i, size in enumerate(self.type().sizes()):
5 changes: 3 additions & 2 deletions docs/zh_cn/07-developer-guide/test_rewritten_models.md
@@ -18,7 +18,7 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
# Custom rewritten function
@FUNCTION_REWRITER.register_rewriter(
'mmcls.models.classifiers.BaseClassifier.forward', backend='default')
def forward_of_base_classifier(ctx, self, img, *args, **kwargs):
def forward_of_base_classifier(self, img, *args, **kwargs):
"""Rewrite `forward` for default backend."""
return self.simple_test(img, {})
```
@@ -63,7 +63,8 @@ def test_baseclassfier_forward():
# Custom rewritten function
@FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.segmentors.BaseSegmentor.forward')
def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs):
def base_segmentor__forward(self, img, img_metas=None, **kwargs):
ctx = FUNCTION_REWRITER.get_context()
if img_metas is None:
img_metas = {}
assert isinstance(img_metas, dict)
9 changes: 9 additions & 0 deletions mmdeploy/apis/onnx/export.py
@@ -116,6 +116,15 @@ def _add_or_update(cfg: dict, key: str, val: Any):
input_metas, dict
), f'Expect input_metas type is dict, get {type(input_metas)}.'
model_forward = patched_model.forward

def wrap_forward(forward):

def wrapper(*arg, **kwargs):
return forward(*arg, **kwargs)

return wrapper

patched_model.forward = wrap_forward(patched_model.forward)
patched_model.forward = partial(patched_model.forward,
**input_metas)

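
The new `wrap_forward` helper re-wraps the bound `forward` in a plain function before `functools.partial` binds `input_metas`. In isolation the pattern looks like the sketch below; the exact reason it is needed in the export path (for example keeping the partially-applied forward restorable and behaving like an ordinary function during tracing) is my reading rather than something stated in this diff:

```python
from functools import partial

import torch


class ToyModel(torch.nn.Module):

    def forward(self, x, scale=1.0):
        return x * scale


model = ToyModel()
original_forward = model.forward  # keep the bound method to restore later


def wrap_forward(forward):
    # A thin wrapper so that later patching (the partial below) applies to a
    # plain function rather than to the bound method object itself.
    def wrapper(*args, **kwargs):
        return forward(*args, **kwargs)

    return wrapper


model.forward = wrap_forward(model.forward)
model.forward = partial(model.forward, scale=2.0)  # bind the extra kwargs

print(model.forward(torch.ones(2)))  # tensor([2., 2.])

model.forward = original_forward  # restore after export
```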
13 changes: 10 additions & 3 deletions mmdeploy/apis/onnx/optimizer.py
@@ -5,8 +5,9 @@


@FUNCTION_REWRITER.register_rewriter('torch.onnx.utils._model_to_graph')
def model_to_graph__custom_optimizer(ctx, *args, **kwargs):
def model_to_graph__custom_optimizer(*args, **kwargs):
"""Rewriter of _model_to_graph, add custom passes."""
ctx = FUNCTION_REWRITER.get_context()
graph, params_dict, torch_out = ctx.origin_func(*args, **kwargs)

custom_passes = getattr(ctx, 'onnx_custom_passes', None)
@@ -23,10 +24,16 @@ def model_to_graph__custom_optimizer(ctx, *args, **kwargs):

@FUNCTION_REWRITER.register_rewriter(
'torch._C._jit_pass_onnx_deduplicate_initializers', backend='tensorrt')
def jit_pass_onnx_deduplicate_initializers__disable(ctx, graph, param_dict,
arg2):
def jit_pass_onnx_deduplicate_initializers__disable(graph, param_dict, arg2):
"""This pass will disable TensorRT topk export.

disable for TensorRT.
"""
return param_dict


@FUNCTION_REWRITER.register_rewriter(
'torch._C._jit_pass_onnx_autograd_function_process')
def jit_pass_onnx_autograd_function_process__disable(graph):
"""Disable process autograph function."""
return
3 changes: 1 addition & 2 deletions mmdeploy/codebase/mmaction/models/recognizers/base.py
@@ -8,8 +8,7 @@

@FUNCTION_REWRITER.register_rewriter(
'mmaction.models.recognizers.BaseRecognizer.forward')
def base_recognizer__forward(ctx,
self,
def base_recognizer__forward(self,
inputs: Tensor,
data_samples: OptSampleList = None,
mode: str = 'tensor',
2 changes: 1 addition & 1 deletion mmdeploy/codebase/mmcls/models/backbones/shufflenet_v2.py
@@ -6,7 +6,7 @@

@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.backbones.shufflenet_v2.InvertedResidual.forward')
def shufflenetv2_backbone__forward__default(ctx, self, x):
def shufflenetv2_backbone__forward__default(self, x):
"""Rewrite `forward` of InvertedResidual used in shufflenet_v2.

The chunk in original InvertedResidual.forward will convert to dynamic
mmdeploy/codebase/mmcls/models/backbones/vision_transformer.py
@@ -9,7 +9,7 @@
func_name= # noqa: E251
'mmcls.models.backbones.vision_transformer.VisionTransformer.forward',
backend=Backend.NCNN.value)
def visiontransformer__forward__ncnn(ctx, self, x):
def visiontransformer__forward__ncnn(self, x):
"""Rewrite `forward` of VisionTransformer for ncnn backend.

The chunk in original VisionTransformer.forward will convert
1 change: 0 additions & 1 deletion mmdeploy/codebase/mmcls/models/classifiers/base.py
@@ -12,7 +12,6 @@
@FUNCTION_REWRITER.register_rewriter(
'mmcls.models.classifiers.BaseClassifier.forward', backend='default')
def base_classifier__forward(
ctx,
self,
batch_inputs: Tensor,
data_samples: Optional[List[BaseDataElement]] = None,
2 changes: 1 addition & 1 deletion mmdeploy/codebase/mmcls/models/necks/gap.py
@@ -9,7 +9,7 @@
@FUNCTION_REWRITER.register_rewriter(
'mmcls.models.necks.GlobalAveragePooling.forward',
backend=Backend.DEFAULT.value)
def gap__forward(ctx, self, inputs):
def gap__forward(self, inputs):
"""Rewrite `forward` of GlobalAveragePooling for default backend.

Replace `view` with `flatten` to export simple onnx graph.
8 changes: 4 additions & 4 deletions mmdeploy/codebase/mmcls/models/utils/attention.py
@@ -11,7 +11,7 @@
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.utils.attention.MultiheadAttention.forward',
backend=Backend.NCNN.value)
def multiheadattention__forward__ncnn(ctx, self, qkv_input):
def multiheadattention__forward__ncnn(self, qkv_input):
"""Rewrite `forward` of MultiheadAttention used in vision_transformer for
ncnn backend.

@@ -53,12 +53,13 @@ def multiheadattention__forward__ncnn(ctx, self, qkv_input):
func_name= # noqa: E251
'mmcls.models.utils.ShiftWindowMSA.forward',
extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0'))
def shift_window_msa__forward__default(ctx, self, query, hw_shape):
def shift_window_msa__forward__default(self, query, hw_shape):
"""Rewrite forward function of ShiftWindowMSA class for TensorRT.

1. replace dynamic padding with static padding and dynamic slice.
2. always do slice `x = x[:, :H, :W, :].contiguous()` for stability.
"""
ctx = FUNCTION_REWRITER.get_context()
if get_dynamic_axes(ctx.cfg) is None:
# avoid the weird bug of torch to onnx
return ctx.origin_func(self, query, hw_shape)
@@ -142,8 +143,7 @@ def shift_window_msa__forward__default(ctx, self, query, hw_shape):
func_name= # noqa: E251
'mmcls.models.utils.ShiftWindowMSA.get_attn_mask',
extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0'))
def shift_window_msa__get_attn_mask__default(ctx,
self,
def shift_window_msa__get_attn_mask__default(self,
hw_shape,
window_size,
shift_size,
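
The hunks above also show the two context attributes most rewriters rely on after this refactor: `ctx.cfg`, the deploy config, and `ctx.origin_func`, the original un-rewritten callable. A hedged sketch of a rewriter that uses both (the target `some.module.Module.forward`, the config key, and the helper methods are hypothetical):

```python
from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
    func_name='some.module.Module.forward')
def module__forward(self, x):
    ctx = FUNCTION_REWRITER.get_context()
    deploy_cfg = ctx.cfg  # the deploy config is reachable through the context
    if deploy_cfg.get('fallback_to_origin', False):  # hypothetical key
        # Delegate to the original implementation when no rewrite is needed.
        return ctx.origin_func(self, x)
    # An export-friendly re-implementation would go here.
    return self.forward_export(x)  # hypothetical helper
```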
10 changes: 4 additions & 6 deletions mmdeploy/codebase/mmdet/deploy/utils.py
@@ -76,7 +76,7 @@ def clip_bboxes(x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor,
func_name='mmdeploy.codebase.mmdet.deploy.utils.clip_bboxes',
backend='tensorrt',
extra_checkers=LibVersionChecker('tensorrt', min_version='8'))
def clip_bboxes__trt8(ctx, x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor,
def clip_bboxes__trt8(x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor,
max_shape: Union[Tensor, Sequence[int]]):
"""Clip bboxes for onnx. From TensorRT 8 we can do the operators on the
tensors directly.
@@ -165,7 +165,6 @@ def __pad_with_value_if_necessary(x: Tensor,
'mmdeploy.codebase.mmdet.deploy.utils.__pad_with_value_if_necessary',
backend=Backend.TENSORRT.value)
def __pad_with_value_if_necessary__tensorrt(
ctx,
x: Tensor,
pad_dim: int,
pad_size: int,
@@ -223,12 +222,12 @@ def symbolic(g, x, inds):
@FUNCTION_REWRITER.register_rewriter(
'mmdeploy.codebase.mmdet.deploy.utils.__gather_topk',
backend=Backend.TENSORRT.value)
def __gather_topk__trt(ctx,
*inputs: Sequence[torch.Tensor],
def __gather_topk__trt(*inputs: Sequence[torch.Tensor],
inds: torch.Tensor,
batch_size: int,
is_batched: bool = True) -> Tuple[torch.Tensor]:
"""TensorRT gather_topk."""
ctx = FUNCTION_REWRITER.get_context()
_ = ctx
if is_batched:
index_shape = inds.shape
@@ -253,8 +252,7 @@ def __gather_topk__trt(ctx,
@FUNCTION_REWRITER.register_rewriter(
'mmdeploy.codebase.mmdet.deploy.utils.__gather_topk',
backend=Backend.COREML.value)
def __gather_topk__nonbatch(ctx,
*inputs: Sequence[torch.Tensor],
def __gather_topk__nonbatch(*inputs: Sequence[torch.Tensor],
inds: torch.Tensor,
batch_size: int,
is_batched: bool = True) -> Tuple[torch.Tensor]:
13 changes: 7 additions & 6 deletions mmdeploy/codebase/mmdet/models/backbones.py
@@ -7,7 +7,7 @@

@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.csp_darknet.Focus.forward')
def focus__forward__default(ctx, self, x):
def focus__forward__default(self, x):
"""Rewrite forward function of Focus class.

Replace slice with transpose.
@@ -27,7 +27,7 @@ def focus__forward__default(ctx, self, x):
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.csp_darknet.Focus.forward',
backend='ncnn')
def focus__forward__ncnn(ctx, self, x):
def focus__forward__ncnn(self, x):
"""Rewrite forward function of Focus class for ncnn.

Focus width and height information into channel space. ncnn does not
@@ -69,7 +69,7 @@ def focus__forward__ncnn(ctx, self, x):
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.swin.WindowMSA.forward',
backend='tensorrt')
def windowmsa__forward__tensorrt(ctx, self, x, mask=None):
def windowmsa__forward__tensorrt(self, x, mask=None):
"""Rewrite forward function of WindowMSA class for TensorRT.

1. replace Gather operation of qkv with split.
@@ -80,6 +80,7 @@ def windowmsa__forward__tensorrt(ctx, self, x, mask=None):
mask (tensor | None, Optional): mask with shape of (num_windows,
Wh*Ww, Wh*Ww), value should be between (-inf, 0].
"""
ctx = FUNCTION_REWRITER.get_context()
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
-1).permute(2, 0, 3, 1, 4).contiguous()
@@ -129,7 +130,7 @@ def windowmsa__forward__tensorrt(ctx, self, x, mask=None):
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.swin.ShiftWindowMSA.window_reverse',
backend='tensorrt')
def shift_window_msa__window_reverse__tensorrt(ctx, self, windows, H, W):
def shift_window_msa__window_reverse__tensorrt(self, windows, H, W):
"""Rewrite window_reverse function of ShiftWindowMSA class for TensorRT.
For TensorRT, seems radical shape transformations are not allowed. Replace
them with soft ones.
@@ -155,7 +156,7 @@ def shift_window_msa__window_reverse__tensorrt(ctx, self, windows, H, W):
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.swin.ShiftWindowMSA.window_partition',
backend='tensorrt')
def shift_window_msa__window_partition__tensorrt(ctx, self, x):
def shift_window_msa__window_partition__tensorrt(self, x):
"""Rewrite window_partition function of ShiftWindowMSA class for TensorRT.
For TensorRT, seems radical shape transformations are not allowed. Replace
them with soft ones.
@@ -176,7 +177,7 @@ def shift_window_msa__window_partition__tensorrt(ctx, self, x):

@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.swin.ShiftWindowMSA.forward')
def shift_window_msa__forward__default(ctx, self, query, hw_shape):
def shift_window_msa__forward__default(self, query, hw_shape):
"""Rewrite forward function of ShiftWindowMSA class.

1. replace dynamic padding with static padding and dynamic slice.