From 582781f817d95bc33892d64fdaa6d81ded6ec615 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 2 Dec 2022 21:40:48 +0800 Subject: [PATCH 1/8] wip --- mmdeploy/apis/onnx/export.py | 9 ++ mmdeploy/apis/onnx/optimizer.py | 13 ++- .../mmaction/models/recognizers/base.py | 3 +- .../mmcls/models/backbones/shufflenet_v2.py | 2 +- .../models/backbones/vision_transformer.py | 2 +- .../codebase/mmcls/models/classifiers/base.py | 1 - mmdeploy/codebase/mmcls/models/necks/gap.py | 2 +- .../codebase/mmcls/models/utils/attention.py | 8 +- mmdeploy/codebase/mmdet/models/backbones.py | 13 +-- .../models/dense_heads/base_dense_head.py | 7 +- .../models/dense_heads/centernet_head.py | 1 - .../mmdet/models/dense_heads/detr_head.py | 5 +- .../mmdet/models/dense_heads/fovea_head.py | 4 +- .../mmdet/models/dense_heads/gfl_head.py | 4 +- .../models/dense_heads/reppoints_head.py | 5 +- .../mmdet/models/dense_heads/rpn_head.py | 8 +- .../mmdet/models/dense_heads/rtmdet_head.py | 4 +- .../mmdet/models/dense_heads/yolo_head.py | 8 +- .../mmdet/models/dense_heads/yolox_head.py | 7 +- .../mmdet/models/detectors/single_stage.py | 4 +- .../mmdet/models/detectors/two_stage.py | 15 ++-- .../codebase/mmdet/models/layers/bbox_nms.py | 35 +++++--- mmdeploy/codebase/mmdet/models/necks.py | 5 +- .../mmdet/models/roi_heads/bbox_head.py | 21 +++-- .../models/roi_heads/cascade_roi_head.py | 6 +- .../mmdet/models/roi_heads/fcn_mask_head.py | 4 +- .../roi_heads/single_level_roi_extractor.py | 31 +++---- .../models/roi_heads/standard_roi_head.py | 6 +- .../coders/delta_xywh_bbox_coder.py | 9 +- .../coders/distance_point_bbox_coder.py | 6 +- .../task_modules/coders/tblr_bbox_coder.py | 3 +- .../task_modules/prior_generators/anchor.py | 19 ++-- .../prior_generators/point_generator.py | 1 - mmdeploy/codebase/mmdet/models/transformer.py | 2 +- .../mmdet/structures/bbox/transforms.py | 2 +- mmdeploy/core/optimizers/function_marker.py | 14 +-- mmdeploy/core/rewriters/function_rewriter.py | 36 ++++++-- 
mmdeploy/core/rewriters/rewriter_utils.py | 43 ++++++++++ mmdeploy/core/rewriters/symbolic_rewriter.py | 40 ++++++++- mmdeploy/mmcv/cnn/__init__.py | 1 - mmdeploy/mmcv/cnn/conv2d_adaptive_padding.py | 86 ------------------- mmdeploy/mmcv/cnn/transformer.py | 3 +- mmdeploy/mmcv/ops/deform_conv.py | 6 +- mmdeploy/mmcv/ops/modulated_deform_conv.py | 5 +- mmdeploy/mmcv/ops/nms.py | 72 ++++------------ mmdeploy/mmcv/ops/point_sample.py | 4 +- mmdeploy/mmcv/ops/roi_align.py | 17 ++-- mmdeploy/mmcv/ops/roi_align_rotated.py | 2 +- mmdeploy/mmcv/ops/transformer.py | 2 +- mmdeploy/pytorch/functions/adaptive_pool.py | 6 +- mmdeploy/pytorch/functions/atan2.py | 1 - mmdeploy/pytorch/functions/chunk.py | 7 +- mmdeploy/pytorch/functions/clip.py | 3 +- mmdeploy/pytorch/functions/expand.py | 3 +- mmdeploy/pytorch/functions/flatten.py | 2 +- mmdeploy/pytorch/functions/getattribute.py | 3 +- mmdeploy/pytorch/functions/group_norm.py | 1 - mmdeploy/pytorch/functions/interpolate.py | 10 +-- mmdeploy/pytorch/functions/linear.py | 2 +- mmdeploy/pytorch/functions/masked_fill.py | 5 +- mmdeploy/pytorch/functions/mod.py | 7 +- .../functions/multi_head_attention_forward.py | 1 - mmdeploy/pytorch/functions/normalize.py | 4 +- mmdeploy/pytorch/functions/pad.py | 3 +- mmdeploy/pytorch/functions/repeat.py | 5 +- mmdeploy/pytorch/functions/size.py | 6 +- mmdeploy/pytorch/functions/tensor_getitem.py | 3 +- mmdeploy/pytorch/functions/tensor_setitem.py | 5 +- mmdeploy/pytorch/functions/topk.py | 8 +- mmdeploy/pytorch/functions/triu.py | 3 +- mmdeploy/pytorch/symbolics/adaptive_pool.py | 2 +- mmdeploy/pytorch/symbolics/gelu.py | 13 ++- mmdeploy/pytorch/symbolics/grid_sampler.py | 3 +- mmdeploy/pytorch/symbolics/hardsigmoid.py | 2 +- mmdeploy/pytorch/symbolics/instance_norm.py | 2 +- mmdeploy/pytorch/symbolics/layer_norm.py | 4 +- mmdeploy/pytorch/symbolics/linear.py | 2 +- mmdeploy/pytorch/symbolics/lstm.py | 3 +- mmdeploy/pytorch/symbolics/roll.py | 2 +- mmdeploy/pytorch/symbolics/squeeze.py | 2 +- 
tests/test_core/test_function_rewriter.py | 28 +++--- tests/test_core/test_symbolic_register.py | 15 ++-- tests/test_mmcv/test_mmcv_cnn.py | 27 ------ tests/test_mmcv/test_mmcv_ops.py | 21 ++--- 84 files changed, 406 insertions(+), 414 deletions(-) delete mode 100644 mmdeploy/mmcv/cnn/conv2d_adaptive_padding.py diff --git a/mmdeploy/apis/onnx/export.py b/mmdeploy/apis/onnx/export.py index 6ca127af90..92a9002d8d 100644 --- a/mmdeploy/apis/onnx/export.py +++ b/mmdeploy/apis/onnx/export.py @@ -116,6 +116,15 @@ def _add_or_update(cfg: dict, key: str, val: Any): input_metas, dict ), f'Expect input_metas type is dict, get {type(input_metas)}.' model_forward = patched_model.forward + + def wrap_forward(forward): + + def wrapper(*arg, **kwargs): + return forward(*arg, **kwargs) + + return wrapper + + patched_model.forward = wrap_forward(patched_model.forward) patched_model.forward = partial(patched_model.forward, **input_metas) diff --git a/mmdeploy/apis/onnx/optimizer.py b/mmdeploy/apis/onnx/optimizer.py index b9d2ead0c0..bfc2cc0abd 100644 --- a/mmdeploy/apis/onnx/optimizer.py +++ b/mmdeploy/apis/onnx/optimizer.py @@ -5,8 +5,9 @@ @FUNCTION_REWRITER.register_rewriter('torch.onnx.utils._model_to_graph') -def model_to_graph__custom_optimizer(ctx, *args, **kwargs): +def model_to_graph__custom_optimizer(*args, **kwargs): """Rewriter of _model_to_graph, add custom passes.""" + ctx = FUNCTION_REWRITER.get_context() graph, params_dict, torch_out = ctx.origin_func(*args, **kwargs) custom_passes = getattr(ctx, 'onnx_custom_passes', None) @@ -23,10 +24,16 @@ def model_to_graph__custom_optimizer(ctx, *args, **kwargs): @FUNCTION_REWRITER.register_rewriter( 'torch._C._jit_pass_onnx_deduplicate_initializers', backend='tensorrt') -def jit_pass_onnx_deduplicate_initializers__disable(ctx, graph, param_dict, - arg2): +def jit_pass_onnx_deduplicate_initializers__disable(graph, param_dict, arg2): """This pass will disable TensorRT topk export. disable for TensorRT. 
""" return param_dict + + +@FUNCTION_REWRITER.register_rewriter( + 'torch._C._jit_pass_onnx_autograd_function_process') +def jit_pass_onnx_autograd_function_process__disable(graph): + """Disable process autograph function.""" + return diff --git a/mmdeploy/codebase/mmaction/models/recognizers/base.py b/mmdeploy/codebase/mmaction/models/recognizers/base.py index 5504f2166c..7e667e128a 100644 --- a/mmdeploy/codebase/mmaction/models/recognizers/base.py +++ b/mmdeploy/codebase/mmaction/models/recognizers/base.py @@ -8,8 +8,7 @@ @FUNCTION_REWRITER.register_rewriter( 'mmaction.models.recognizers.BaseRecognizer.forward') -def base_recognizer__forward(ctx, - self, +def base_recognizer__forward(self, inputs: Tensor, data_samples: OptSampleList = None, mode: str = 'tensor', diff --git a/mmdeploy/codebase/mmcls/models/backbones/shufflenet_v2.py b/mmdeploy/codebase/mmcls/models/backbones/shufflenet_v2.py index fe3a73d0b2..d47c0c6cfa 100644 --- a/mmdeploy/codebase/mmcls/models/backbones/shufflenet_v2.py +++ b/mmdeploy/codebase/mmcls/models/backbones/shufflenet_v2.py @@ -6,7 +6,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmcls.models.backbones.shufflenet_v2.InvertedResidual.forward') -def shufflenetv2_backbone__forward__default(ctx, self, x): +def shufflenetv2_backbone__forward__default(self, x): """Rewrite `forward` of InvertedResidual used in shufflenet_v2. 
The chunk in original InvertedResidual.forward will convert to dynamic diff --git a/mmdeploy/codebase/mmcls/models/backbones/vision_transformer.py b/mmdeploy/codebase/mmcls/models/backbones/vision_transformer.py index a31853912c..2acf13bb81 100644 --- a/mmdeploy/codebase/mmcls/models/backbones/vision_transformer.py +++ b/mmdeploy/codebase/mmcls/models/backbones/vision_transformer.py @@ -9,7 +9,7 @@ func_name= # noqa: E251 'mmcls.models.backbones.vision_transformer.VisionTransformer.forward', backend=Backend.NCNN.value) -def visiontransformer__forward__ncnn(ctx, self, x): +def visiontransformer__forward__ncnn(self, x): """Rewrite `forward` of VisionTransformer for ncnn backend. The chunk in original VisionTransformer.forward will convert diff --git a/mmdeploy/codebase/mmcls/models/classifiers/base.py b/mmdeploy/codebase/mmcls/models/classifiers/base.py index 7a324bdbbc..1a54c73f7c 100644 --- a/mmdeploy/codebase/mmcls/models/classifiers/base.py +++ b/mmdeploy/codebase/mmcls/models/classifiers/base.py @@ -12,7 +12,6 @@ @FUNCTION_REWRITER.register_rewriter( 'mmcls.models.classifiers.BaseClassifier.forward', backend='default') def base_classifier__forward( - ctx, self, batch_inputs: Tensor, data_samples: Optional[List[BaseDataElement]] = None, diff --git a/mmdeploy/codebase/mmcls/models/necks/gap.py b/mmdeploy/codebase/mmcls/models/necks/gap.py index d89939def4..f17d0ebac5 100644 --- a/mmdeploy/codebase/mmcls/models/necks/gap.py +++ b/mmdeploy/codebase/mmcls/models/necks/gap.py @@ -9,7 +9,7 @@ @FUNCTION_REWRITER.register_rewriter( 'mmcls.models.necks.GlobalAveragePooling.forward', backend=Backend.DEFAULT.value) -def gap__forward(ctx, self, inputs): +def gap__forward(self, inputs): """Rewrite `forward` of GlobalAveragePooling for default backend. Replace `view` with `flatten` to export simple onnx graph. 
diff --git a/mmdeploy/codebase/mmcls/models/utils/attention.py b/mmdeploy/codebase/mmcls/models/utils/attention.py index edbbc11690..96adea1ad9 100644 --- a/mmdeploy/codebase/mmcls/models/utils/attention.py +++ b/mmdeploy/codebase/mmcls/models/utils/attention.py @@ -11,7 +11,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmcls.models.utils.attention.MultiheadAttention.forward', backend=Backend.NCNN.value) -def multiheadattention__forward__ncnn(ctx, self, qkv_input): +def multiheadattention__forward__ncnn(self, qkv_input): """Rewrite `forward` of MultiheadAttention used in vision_transformer for ncnn backend. @@ -53,12 +53,13 @@ def multiheadattention__forward__ncnn(ctx, self, qkv_input): func_name= # noqa: E251 'mmcls.models.utils.ShiftWindowMSA.forward', extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0')) -def shift_window_msa__forward__default(ctx, self, query, hw_shape): +def shift_window_msa__forward__default(self, query, hw_shape): """Rewrite forward function of ShiftWindowMSA class for TensorRT. 1. replace dynamic padding with static padding and dynamic slice. 2. always do slice `x = x[:, :H, :W, :].contiguous()` for stability. 
""" + ctx = FUNCTION_REWRITER.get_context() if get_dynamic_axes(ctx.cfg) is None: # avoid the weird bug of torch to onnx return ctx.origin_func(self, query, hw_shape) @@ -142,8 +143,7 @@ def shift_window_msa__forward__default(ctx, self, query, hw_shape): func_name= # noqa: E251 'mmcls.models.utils.ShiftWindowMSA.get_attn_mask', extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0')) -def shift_window_msa__get_attn_mask__default(ctx, - self, +def shift_window_msa__get_attn_mask__default(self, hw_shape, window_size, shift_size, diff --git a/mmdeploy/codebase/mmdet/models/backbones.py b/mmdeploy/codebase/mmdet/models/backbones.py index 122a362b97..6f6a72d5ca 100644 --- a/mmdeploy/codebase/mmdet/models/backbones.py +++ b/mmdeploy/codebase/mmdet/models/backbones.py @@ -7,7 +7,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.backbones.csp_darknet.Focus.forward') -def focus__forward__default(ctx, self, x): +def focus__forward__default(self, x): """Rewrite forward function of Focus class. Replace slice with transpose. @@ -27,7 +27,7 @@ def focus__forward__default(ctx, self, x): @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.backbones.csp_darknet.Focus.forward', backend='ncnn') -def focus__forward__ncnn(ctx, self, x): +def focus__forward__ncnn(self, x): """Rewrite forward function of Focus class for ncnn. Focus width and height information into channel space. ncnn does not @@ -69,7 +69,7 @@ def focus__forward__ncnn(ctx, self, x): @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.backbones.swin.WindowMSA.forward', backend='tensorrt') -def windowmsa__forward__tensorrt(ctx, self, x, mask=None): +def windowmsa__forward__tensorrt(self, x, mask=None): """Rewrite forward function of WindowMSA class for TensorRT. 1. replace Gather operation of qkv with split. 
@@ -80,6 +80,7 @@ def windowmsa__forward__tensorrt(ctx, self, x, mask=None): mask (tensor | None, Optional): mask with shape of (num_windows, Wh*Ww, Wh*Ww), value should be between (-inf, 0]. """ + ctx = FUNCTION_REWRITER.get_context() B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4).contiguous() @@ -129,7 +130,7 @@ def windowmsa__forward__tensorrt(ctx, self, x, mask=None): @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.backbones.swin.ShiftWindowMSA.window_reverse', backend='tensorrt') -def shift_window_msa__window_reverse__tensorrt(ctx, self, windows, H, W): +def shift_window_msa__window_reverse__tensorrt(self, windows, H, W): """Rewrite window_reverse function of ShiftWindowMSA class for TensorRT. For TensorRT, seems radical shape transformations are not allowed. Replace them with soft ones. @@ -155,7 +156,7 @@ def shift_window_msa__window_reverse__tensorrt(ctx, self, windows, H, W): @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.backbones.swin.ShiftWindowMSA.window_partition', backend='tensorrt') -def shift_window_msa__window_partition__tensorrt(ctx, self, x): +def shift_window_msa__window_partition__tensorrt(self, x): """Rewrite window_partition function of ShiftWindowMSA class for TensorRT. For TensorRT, seems radical shape transformations are not allowed. Replace them with soft ones. @@ -176,7 +177,7 @@ def shift_window_msa__window_partition__tensorrt(ctx, self, x): @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.backbones.swin.ShiftWindowMSA.forward') -def shift_window_msa__forward__default(ctx, self, query, hw_shape): +def shift_window_msa__forward__default(self, query, hw_shape): """Rewrite forward function of ShiftWindowMSA class. 1. replace dynamic padding with static padding and dynamic slice. 
diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py index 625ca6d58d..efcf1fe0b6 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py @@ -23,7 +23,6 @@ func_name='mmdet.models.dense_heads.base_dense_head.' 'BaseDenseHead.predict_by_feat') def base_dense_head__predict_by_feat( - ctx, self, cls_scores: List[Tensor], bbox_preds: List[Tensor], @@ -65,6 +64,7 @@ def base_dense_head__predict_by_feat( tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_centerness """ + ctx = FUNCTION_REWRITER.get_context() deploy_cfg = ctx.cfg is_dynamic_flag = is_dynamic_shape(deploy_cfg) num_levels = len(cls_scores) @@ -195,7 +195,6 @@ def base_dense_head__predict_by_feat( 'BaseDenseHead.predict_by_feat', backend=Backend.RKNN.value) def base_dense_head__predict_by_feat__rknn( - ctx, self, cls_scores: List[Tensor], bbox_preds: List[Tensor], @@ -237,6 +236,8 @@ def base_dense_head__predict_by_feat__rknn( tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_centerness """ + ctx = FUNCTION_REWRITER.get_context() + # mark nodes for partition @mark('BaseDenseHead', outputs=['BaseDenseHead.cls', 'BaseDenseHead.loc']) def __mark_dense_head(cls_scores, bbox_preds): @@ -321,7 +322,6 @@ def __mark_dense_head(cls_scores, bbox_preds): 'BaseDenseHead.predict_by_feat', backend=Backend.NCNN.value) def base_dense_head__predict_by_feat__ncnn( - ctx, self, cls_scores: List[Tensor], bbox_preds: List[Tensor], @@ -360,6 +360,7 @@ def base_dense_head__predict_by_feat__ncnn( Returns: output__ncnn (Tensor): outputs, shape is [N, num_det, 6]. 
""" + ctx = FUNCTION_REWRITER.get_context() assert len(cls_scores) == len(bbox_preds) deploy_cfg = ctx.cfg assert not is_dynamic_shape(deploy_cfg), 'base_dense_head for ncnn\ diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/centernet_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/centernet_head.py index e5130489b7..9453b2dac2 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/centernet_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/centernet_head.py @@ -9,7 +9,6 @@ @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.dense_heads.centernet_head.CenterNetHead.predict_by_feat') def centernet_head__predict_by_feat__default( - ctx, self, center_heatmap_preds: List[Tensor], wh_preds: List[Tensor], diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py index 3ef050d5c7..8af369913e 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py @@ -10,7 +10,7 @@ @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.dense_heads.DETRHead.forward_single') -def detrhead__forward_single__default(ctx, self, x, img_metas): +def detrhead__forward_single__default(self, x, img_metas): """forward_single of DETRHead. 
Ease the mask computation @@ -35,8 +35,7 @@ def detrhead__forward_single__default(ctx, self, x, img_metas): @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.dense_heads.DETRHead.predict_by_feat') -def detrhead__predict_by_feat__default(ctx, - self, +def detrhead__predict_by_feat__default(self, all_cls_scores_list: List[Tensor], all_bbox_preds_list: List[Tensor], batch_img_metas: List[dict], diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/fovea_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/fovea_head.py index 110a4045f1..544aba0acf 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/fovea_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/fovea_head.py @@ -13,8 +13,7 @@ @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.dense_heads.fovea_head.FoveaHead.predict_by_feat') -def fovea_head__predict_by_feat(ctx, - self, +def fovea_head__predict_by_feat(self, cls_scores: List[Tensor], bbox_preds: List[Tensor], score_factors: Optional[List[Tensor]] = None, @@ -49,6 +48,7 @@ def fovea_head__predict_by_feat(ctx, `dets` of shape [N, num_det, 5] and `labels` of shape [N, num_det]. """ + ctx = FUNCTION_REWRITER.get_context() assert len(cls_scores) == len(bbox_preds) cfg = self.test_cfg if cfg is None else cfg num_levels = len(cls_scores) diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py index 4583d01019..fce43d8c66 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py @@ -18,8 +18,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.dense_heads.gfl_head.' 
'GFLHead.predict_by_feat') -def gfl_head__predict_by_feat(ctx, - self, +def gfl_head__predict_by_feat(self, cls_scores: List[Tensor], bbox_preds: List[Tensor], score_factors: Optional[List[Tensor]] = None, @@ -58,6 +57,7 @@ def gfl_head__predict_by_feat(ctx, tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_centerness """ + ctx = FUNCTION_REWRITER.get_context() deploy_cfg = ctx.cfg is_dynamic_flag = is_dynamic_shape(deploy_cfg) backend = get_backend(deploy_cfg) diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/reppoints_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/reppoints_head.py index e24998a925..5e16741214 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/reppoints_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/reppoints_head.py @@ -35,12 +35,13 @@ def _bbox_post_decode(bboxes: torch.Tensor, max_shape: Sequence[int]): @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.dense_heads.reppoints_head.RepPointsHead.points2bbox') -def reppoints_head__points2bbox(ctx, self, pts, y_first=True): +def reppoints_head__points2bbox(self, pts, y_first=True): """Rewrite of `points2bbox` in `RepPointsHead`. Use `self.moment_transfer` in `points2bbox` will cause error: RuntimeError: Input, output and indices must be on the current device """ + ctx = FUNCTION_REWRITER.get_context() update_moment = hasattr(self, 'moment_transfer') if update_moment: moment_transfer = self.moment_transfer @@ -55,7 +56,6 @@ def reppoints_head__points2bbox(ctx, self, pts, y_first=True): @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.dense_heads.reppoints_head.RepPointsHead.predict_by_feat') def reppoints_head__predict_by_feat( - ctx, self, cls_scores: List[Tensor], bbox_preds: List[Tensor], @@ -91,6 +91,7 @@ def reppoints_head__predict_by_feat( `dets` of shape [N, num_det, 5] and `labels` of shape [N, num_det]. 
""" + ctx = FUNCTION_REWRITER.get_context() deploy_cfg = ctx.cfg is_dynamic_flag = is_dynamic_shape(deploy_cfg) num_levels = len(cls_scores) diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py index 05f2484867..bb962aff22 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py @@ -16,8 +16,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.dense_heads.rpn_head.' 'RPNHead.predict_by_feat') -def rpn_head__predict_by_feat(ctx, - self, +def rpn_head__predict_by_feat(self, cls_scores: List[Tensor], bbox_preds: List[Tensor], score_factors: Optional[List[Tensor]] = None, @@ -61,6 +60,7 @@ def rpn_head__predict_by_feat(ctx, tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_centerness """ + ctx = FUNCTION_REWRITER.get_context() img_metas = batch_img_metas assert len(cls_scores) == len(bbox_preds) deploy_cfg = ctx.cfg @@ -163,8 +163,7 @@ def rpn_head__predict_by_feat(ctx, # TODO: Fix for 1.x @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.dense_heads.RPNHead.get_bboxes', backend=Backend.NCNN.value) -def rpn_head__get_bboxes__ncnn(ctx, - self, +def rpn_head__get_bboxes__ncnn(self, cls_scores, bbox_preds, img_metas, @@ -201,6 +200,7 @@ def rpn_head__get_bboxes__ncnn(ctx, Else: tuple[Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores """ + ctx = FUNCTION_REWRITER.get_context() assert len(cls_scores) == len(bbox_preds) deploy_cfg = ctx.cfg assert not is_dynamic_shape(deploy_cfg) diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/rtmdet_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/rtmdet_head.py index 3ad883de53..7e69acc228 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/rtmdet_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/rtmdet_head.py @@ -14,8 +14,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.dense_heads.rtmdet_head.' 
'RTMDetHead.predict_by_feat') -def rtmdet_head__predict_by_feat(ctx, - self, +def rtmdet_head__predict_by_feat(self, cls_scores: List[Tensor], bbox_preds: List[Tensor], batch_img_metas: Optional[List[dict]] = None, @@ -52,6 +51,7 @@ def rtmdet_head__predict_by_feat(ctx, tensor in the tuple is (N, num_box), and each element represents the class label of the corresponding box. """ + ctx = FUNCTION_REWRITER.get_context() assert len(cls_scores) == len(bbox_preds) device = cls_scores[0].device cfg = self.test_cfg if cfg is None else cfg diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py index 33cd10b61a..9063777f51 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py @@ -15,8 +15,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.dense_heads.yolo_head.' 'YOLOV3Head.predict_by_feat') -def yolov3_head__predict_by_feat(ctx, - self, +def yolov3_head__predict_by_feat(self, pred_maps: Sequence[Tensor], cfg: OptConfigType = None, rescale: bool = False, @@ -47,6 +46,7 @@ def yolov3_head__predict_by_feat(ctx, Else: tuple[Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores """ + ctx = FUNCTION_REWRITER.get_context() deploy_cfg = ctx.cfg # mark pred_maps @@ -152,8 +152,7 @@ def __mark_pred_maps(pred_maps): @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.dense_heads.YOLOV3Head.predict_by_feat', backend=Backend.NCNN.value) -def yolov3_head__predict_by_feat__ncnn(ctx, - self, +def yolov3_head__predict_by_feat__ncnn(self, pred_maps, with_nms=True, cfg=None, @@ -186,6 +185,7 @@ def yolov3_head__predict_by_feat__ncnn(ctx, fore-ground class label in Yolov3DetectionOutput starts from `1`. x1, y1, x2, y2 are normalized in range(0,1). 
""" + ctx = FUNCTION_REWRITER.get_context() num_levels = len(pred_maps) cfg = self.test_cfg if cfg is None else cfg post_params = get_post_processing_params(ctx.cfg) diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/yolox_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/yolox_head.py index 47a696fbbf..98a6867934 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/yolox_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/yolox_head.py @@ -15,8 +15,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.dense_heads.yolox_head.' 'YOLOXHead.predict_by_feat') -def yolox_head__predict_by_feat(ctx, - self, +def yolox_head__predict_by_feat(self, cls_scores: List[Tensor], bbox_preds: List[Tensor], objectnesses: Optional[List[Tensor]], @@ -57,6 +56,8 @@ def yolox_head__predict_by_feat(ctx, tensor in the tuple is (N, num_box), and each element represents the class label of the corresponding box. """ + ctx = FUNCTION_REWRITER.get_context() + # mark pred_maps @mark('yolo_head', inputs=['cls_scores', 'bbox_preds', 'objectnesses']) def __mark_pred_maps(cls_scores, bbox_preds, objectnesses): @@ -118,7 +119,6 @@ def __mark_pred_maps(cls_scores, bbox_preds, objectnesses): 'YOLOXHead.predict_by_feat', backend=Backend.NCNN.value) def yolox_head__predict_by_feat__ncnn( - ctx, self, cls_scores: List[Tensor], bbox_preds: List[Tensor], @@ -162,6 +162,7 @@ def yolox_head__predict_by_feat__ncnn( Returns: output__ncnn (Tensor): outputs, shape is [N, num_det, 6]. 
""" + ctx = FUNCTION_REWRITER.get_context() from mmdeploy.codebase.mmdet.ops import ncnn_detection_output_forward from mmdeploy.utils import get_root_logger from mmdeploy.utils.config_utils import is_dynamic_shape diff --git a/mmdeploy/codebase/mmdet/models/detectors/single_stage.py b/mmdeploy/codebase/mmdet/models/detectors/single_stage.py index adfb6831f6..9f256fc2d8 100644 --- a/mmdeploy/codebase/mmdet/models/detectors/single_stage.py +++ b/mmdeploy/codebase/mmdet/models/detectors/single_stage.py @@ -27,8 +27,7 @@ def __forward_impl(ctx, self, batch_inputs, data_samples, **kwargs): @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.detectors.single_stage.SingleStageDetector.forward') -def single_stage_detector__forward(ctx, - self, +def single_stage_detector__forward(self, batch_inputs: torch.Tensor, data_samples: OptSampleList = None, mode: str = 'tensor', @@ -53,6 +52,7 @@ def single_stage_detector__forward(ctx, - labels (Tensor): Labels of bboxes, has a shape (num_instances, ). """ + ctx = FUNCTION_REWRITER.get_context() data_samples = copy.deepcopy(data_samples) if data_samples is None: data_samples = [DetDataSample()] diff --git a/mmdeploy/codebase/mmdet/models/detectors/two_stage.py b/mmdeploy/codebase/mmdet/models/detectors/two_stage.py index 9b20fed83c..d0bd140003 100644 --- a/mmdeploy/codebase/mmdet/models/detectors/two_stage.py +++ b/mmdeploy/codebase/mmdet/models/detectors/two_stage.py @@ -11,8 +11,7 @@ @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.detectors.two_stage.TwoStageDetector.extract_feat') -@mark('extract_feat', inputs='img', outputs='feat') -def two_stage_detector__extract_feat(ctx, self, img): +def two_stage_detector__extract_feat(self, img): """Rewrite `extract_feat` for default backend. This function uses the specific `extract_feat` function for the two @@ -27,13 +26,18 @@ def two_stage_detector__extract_feat(ctx, self, img): list[Tensor]: Each item with shape (N, C, H, W) corresponds one level of backbone and neck features. 
""" - return ctx.origin_func(self, img) + ctx = FUNCTION_REWRITER.get_context() + + @mark('extract_feat', inputs='img', outputs='feat') + def __extract_feat_impl(self, img): + return ctx.origin_func(self, img) + + return __extract_feat_impl(self, img) @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.detectors.two_stage.TwoStageDetector.forward') -def two_stage_detector__forward(ctx, - self, +def two_stage_detector__forward(self, batch_inputs: torch.Tensor, data_samples: OptSampleList = None, mode: str = 'tensor', @@ -58,6 +62,7 @@ def two_stage_detector__forward(ctx, - labels (Tensor): Labels of bboxes, has a shape (num_instances, ). """ + ctx = FUNCTION_REWRITER.get_context() data_samples = copy.deepcopy(data_samples) deploy_cfg = ctx.cfg diff --git a/mmdeploy/codebase/mmdet/models/layers/bbox_nms.py b/mmdeploy/codebase/mmdet/models/layers/bbox_nms.py index b66f1f7655..974e10dcf1 100644 --- a/mmdeploy/codebase/mmdet/models/layers/bbox_nms.py +++ b/mmdeploy/codebase/mmdet/models/layers/bbox_nms.py @@ -88,7 +88,6 @@ def _multiclass_nms(boxes: Tensor, shape (N, num_bboxes, num_classes) and the boxes is of shape (N, num_boxes, 4). """ - max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class]) iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32) score_threshold = torch.tensor([score_threshold], dtype=torch.float32) batch_size = scores.shape[0] @@ -122,7 +121,6 @@ def _multiclass_nms_single(boxes: Tensor, Single batch nms could be optimized. 
""" - max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class]) iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32) score_threshold = torch.tensor([score_threshold], dtype=torch.float32) @@ -166,8 +164,7 @@ def _multiclass_nms_single(boxes: Tensor, @FUNCTION_REWRITER.register_rewriter( func_name='mmdeploy.codebase.mmdet.models.layers.bbox_nms._multiclass_nms') -def multiclass_nms__default(ctx, - boxes: Tensor, +def multiclass_nms__default(boxes: Tensor, scores: Tensor, max_output_boxes_per_class: int = 1000, iou_threshold: float = 0.5, @@ -199,6 +196,7 @@ def multiclass_nms__default(ctx, tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5] and `labels` of shape [N, num_det]. """ + ctx = FUNCTION_REWRITER.get_context() deploy_cfg = ctx.cfg batch_size = boxes.size(0) if not is_dynamic_batch(deploy_cfg) and batch_size == 1: @@ -224,8 +222,7 @@ def multiclass_nms__default(ctx, @FUNCTION_REWRITER.register_rewriter( func_name='mmdeploy.codebase.mmdet.models.layers.bbox_nms._multiclass_nms', backend='tensorrt') -def multiclass_nms_static(ctx, - boxes: Tensor, +def multiclass_nms_static(boxes: Tensor, scores: Tensor, max_output_boxes_per_class: int = 1000, iou_threshold: float = 0.5, @@ -271,16 +268,28 @@ def multiclass_nms_static(ctx, @mark('multiclass_nms', inputs=['boxes', 'scores'], outputs=['dets', 'labels']) -def multiclass_nms(*args, **kwargs): +def multiclass_nms(boxes: Tensor, + scores: Tensor, + max_output_boxes_per_class: int = 1000, + iou_threshold: float = 0.5, + score_threshold: float = 0.05, + pre_top_k: int = -1, + keep_top_k: int = -1): """Wrapper function for `_multiclass_nms`.""" - return _multiclass_nms(*args, **kwargs) + return _multiclass_nms( + boxes, + scores, + max_output_boxes_per_class=max_output_boxes_per_class, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k) @FUNCTION_REWRITER.register_rewriter( 
func_name='mmdeploy.codebase.mmdet.models.layers.bbox_nms._multiclass_nms', backend=Backend.COREML.value) -def multiclass_nms__coreml(ctx, - boxes: Tensor, +def multiclass_nms__coreml(boxes: Tensor, scores: Tensor, max_output_boxes_per_class: int = 1000, iou_threshold: float = 0.5, @@ -340,8 +349,7 @@ def _xywh2xyxy(boxes): @FUNCTION_REWRITER.register_rewriter( func_name='mmdeploy.codebase.mmdet.models.layers.bbox_nms._multiclass_nms', ir=IR.TORCHSCRIPT) -def multiclass_nms__torchscript(ctx, - boxes: Tensor, +def multiclass_nms__torchscript(boxes: Tensor, scores: Tensor, max_output_boxes_per_class: int = 1000, iou_threshold: float = 0.5, @@ -441,8 +449,7 @@ def symbolic(g, bboxes, scores, score_thr, iou_thr, max_size_p_class, @FUNCTION_REWRITER.register_rewriter( func_name='mmdeploy.codebase.mmdet.core.post_processing._multiclass_nms', backend='ascend') -def multiclass_nms__ascend(ctx, - boxes: Tensor, +def multiclass_nms__ascend(boxes: Tensor, scores: Tensor, max_output_boxes_per_class: int = 1000, iou_threshold: float = 0.5, diff --git a/mmdeploy/codebase/mmdet/models/necks.py b/mmdeploy/codebase/mmdet/models/necks.py index adc40fa12d..79c1d117e5 100644 --- a/mmdeploy/codebase/mmdet/models/necks.py +++ b/mmdeploy/codebase/mmdet/models/necks.py @@ -7,7 +7,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.necks.ssd_neck.L2Norm.forward') -def l2norm__forward__default(ctx, self, x): +def l2norm__forward__default(self, x): """Default rewriter for l2norm. Implement with functinoal.normalize . @@ -19,11 +19,12 @@ def l2norm__forward__default(ctx, self, x): @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.necks.ssd_neck.L2Norm.forward', backend=Backend.TENSORRT.value) -def l2norm__forward__tensorrt(ctx, self, x): +def l2norm__forward__tensorrt(self, x): """rewrite `l2norm` for TensorRT. TensorRT7 does not support dynamic clamp, which is used in normalize. 
""" + ctx = FUNCTION_REWRITER.get_context() logger = get_root_logger() trt_version_major = 8 try: diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/bbox_head.py b/mmdeploy/codebase/mmdet/models/roi_heads/bbox_head.py index 28dbeb3266..450f20b4cd 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/bbox_head.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/bbox_head.py @@ -16,11 +16,7 @@ @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.roi_heads.bbox_heads.convfc_bbox_head.ConvFCBBoxHead.forward' ) -@mark( - 'bbox_head_forward', - inputs=['bbox_feats'], - outputs=['cls_score', 'bbox_pred']) -def bbox_head__forward(ctx, self, x): +def bbox_head__forward(self, x): """Rewrite `forward` for default backend. This function uses the specific `forward` function for the BBoxHead @@ -36,13 +32,21 @@ def bbox_head__forward(ctx, self, x): has shape (N, num_det, num_cls) and the bbox_pred has shape (N, num_det, 4). """ - return ctx.origin_func(self, x) + ctx = FUNCTION_REWRITER.get_context() + + @mark( + 'bbox_head_forward', + inputs=['bbox_feats'], + outputs=['cls_score', 'bbox_pred']) + def __forward(self, x): + return ctx.origin_func(self, x) + + return __forward(self, x) @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.roi_heads.bbox_heads.bbox_head.BBoxHead.predict_by_feat') -def bbox_head__predict_by_feat(ctx, - self, +def bbox_head__predict_by_feat(self, rois: Tuple[Tensor], cls_scores: Tuple[Tensor], bbox_preds: Tuple[Tensor], @@ -74,6 +78,7 @@ def bbox_head__predict_by_feat(ctx, - labels (Tensor): Labels of bboxes, has a shape (num_instances, ). """ + ctx = FUNCTION_REWRITER.get_context() assert rois.ndim == 3, 'Only support export two stage ' \ 'model to ONNX ' \ 'with batch dimension. 
' diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py b/mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py index 1c8249b20e..6947570f80 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py @@ -10,8 +10,7 @@ @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.roi_heads.cascade_roi_head.CascadeRoIHead.predict_bbox') -def cascade_roi_head__predict_bbox(ctx, - self, +def cascade_roi_head__predict_bbox(self, x: Tuple[Tensor], batch_img_metas: List[dict], rpn_results_list: List[Tensor], @@ -83,8 +82,7 @@ def cascade_roi_head__predict_bbox(ctx, @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.roi_heads.cascade_roi_head.CascadeRoIHead.predict_mask') -def cascade_roi_head__predict_mask(ctx, - self, +def cascade_roi_head__predict_mask(self, x: Tuple[Tensor], batch_img_metas: List[dict], results_list: List[Tensor], diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/fcn_mask_head.py b/mmdeploy/codebase/mmdet/models/roi_heads/fcn_mask_head.py index 360faeb1aa..9371ff552e 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/fcn_mask_head.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/fcn_mask_head.py @@ -14,8 +14,7 @@ @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.roi_heads.' 'mask_heads.fcn_mask_head.FCNMaskHead.predict_by_feat') -def fcn_mask_head__predict_by_feat(ctx, - self, +def fcn_mask_head__predict_by_feat(self, mask_preds: Tuple[Tensor], results_list: List[Tensor], batch_img_metas: List[dict], @@ -48,6 +47,7 @@ def fcn_mask_head__predict_by_feat(ctx, (num_instances, ). - masks (Tensor): Has a shape (num_instances, H, W). 
""" + ctx = FUNCTION_REWRITER.get_context() ori_shape = batch_img_metas[0]['img_shape'] dets, det_labels = results_list dets = dets.view(-1, 5) diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py b/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py index 2f3da854c0..d8f53ad0fd 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py @@ -61,13 +61,12 @@ def forward(g, *args): (num_proposals, channel, output_size[1], output_size[0])) +@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats']) @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.roi_heads.roi_extractors.' 'single_level_roi_extractor.SingleRoIExtractor.forward', backend='tensorrt') -@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats']) -def single_roi_extractor__forward__tensorrt(ctx, - self, +def single_roi_extractor__forward__tensorrt(self, feats, rois, roi_scale_factor=None): @@ -154,8 +153,7 @@ def forward(ctx, *args): 'mmdet.models.roi_heads.roi_extractors.' 'single_level_roi_extractor.SingleRoIExtractor.forward', backend='ascend') -def single_roi_extractor__forward__ascend(ctx, - self, +def single_roi_extractor__forward__ascend(self, feats, rois, roi_scale_factor=None): @@ -185,14 +183,10 @@ def single_roi_extractor__forward__ascend(ctx, finest_scale, featmap_strides, aligned) +@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats']) @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward') -@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats']) -def single_roi_extractor__forward(ctx, - self, - feats, - rois, - roi_scale_factor=None): +def single_roi_extractor__forward(self, feats, rois, roi_scale_factor=None): """Rewrite `forward` of SingleRoIExtractor for default backend. 
Rewrite this function to: @@ -206,6 +200,8 @@ def single_roi_extractor__forward(ctx, 3. use the roi align in torhcvision to accelerate the inference. """ + ctx = FUNCTION_REWRITER.get_context( + 'mmdet.models.roi_heads.SingleRoIExtractor.forward') backend = get_backend(ctx.cfg) out_size = self.roi_layers[0].output_size num_levels = len(feats) @@ -268,8 +264,8 @@ def forward(g, output_size, featmap_strides, sample_num, rois, *feats): @staticmethod def symbolic(g, output_size, featmap_strides, sample_num, rois, *feats): """Symbolic function for creating onnx op.""" - from torch.onnx.symbolic_helper import _slice_helper - rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5]) + from torch.onnx.symbolic_opset10 import _slice + rois = _slice(g, rois, axes=[1], starts=[1], ends=[5]) domain = 'org.openvinotoolkit' op_name = 'ExperimentalDetectronROIFeatureExtractor' roi_feats = g.op( @@ -291,8 +287,7 @@ def symbolic(g, output_size, featmap_strides, sample_num, rois, *feats): 'mmdet.models.roi_heads.roi_extractors.' 'single_level_roi_extractor.SingleRoIExtractor.forward', backend='openvino') -def single_roi_extractor__forward__openvino(ctx, - self, +def single_roi_extractor__forward__openvino(self, feats, rois, roi_scale_factor=None): @@ -301,6 +296,7 @@ def single_roi_extractor__forward__openvino(ctx, This function uses ExperimentalDetectronROIFeatureExtractor for OpenVINO. """ + ctx = FUNCTION_REWRITER.get_context() # Adding original output to SingleRoIExtractorOpenVINO. 
state = torch._C._get_tracing_state() @@ -317,12 +313,11 @@ def single_roi_extractor__forward__openvino(ctx, return result +@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats']) @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward', backend=Backend.COREML.value) -@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats']) -def single_roi_extractor__forward__coreml(ctx, - self, +def single_roi_extractor__forward__coreml(self, feats, rois, roi_scale_factor=None): diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/standard_roi_head.py b/mmdeploy/codebase/mmdet/models/roi_heads/standard_roi_head.py index 86e78caab8..4f2b9bb600 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/standard_roi_head.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/standard_roi_head.py @@ -10,8 +10,7 @@ @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.roi_heads.standard_roi_head.StandardRoIHead.predict_bbox') -def standard_roi_head__predict_bbox(ctx, - self, +def standard_roi_head__predict_bbox(self, x: Tuple[Tensor], batch_img_metas: List[dict], rpn_results_list: List[Tensor], @@ -71,8 +70,7 @@ def standard_roi_head__predict_bbox(ctx, @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.roi_heads.standard_roi_head.StandardRoIHead.predict_mask') -def standard_roi_head__predict_mask(ctx, - self, +def standard_roi_head__predict_mask(self, x: Tuple[Tensor], batch_img_metas: List[dict], results_list: List[Tensor], diff --git a/mmdeploy/codebase/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py b/mmdeploy/codebase/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py index e4e8f8e826..8a53944233 100644 --- a/mmdeploy/codebase/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py +++ b/mmdeploy/codebase/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py @@ -9,8 +9,7 @@ func_name='mmdet.models.task_modules.coders.delta_xywh_bbox_coder.' 
'DeltaXYWHBBoxCoder.decode', backend='default') -def deltaxywhbboxcoder__decode(ctx, - self, +def deltaxywhbboxcoder__decode(self, bboxes, pred_bboxes, max_shape=None, @@ -51,8 +50,7 @@ def deltaxywhbboxcoder__decode(ctx, func_name='mmdet.models.task_modules.coders' '.delta_xywh_bbox_coder.delta2bbox', backend='default') -def delta2bbox(ctx, - rois, +def delta2bbox(rois, deltas, means=(0., 0., 0., 0.), stds=(1., 1., 1., 1.), @@ -143,8 +141,7 @@ def delta2bbox(ctx, func_name='mmdet.models.task_modules.coders.' 'delta_xywh_bbox_coder.delta2bbox', backend='ncnn') -def delta2bbox__ncnn(ctx, - rois, +def delta2bbox__ncnn(rois, deltas, means=(0., 0., 0., 0.), stds=(1., 1., 1., 1.), diff --git a/mmdeploy/codebase/mmdet/models/task_modules/coders/distance_point_bbox_coder.py b/mmdeploy/codebase/mmdet/models/task_modules/coders/distance_point_bbox_coder.py index 41da1bdfcf..8b8bbdb0e4 100644 --- a/mmdeploy/codebase/mmdet/models/task_modules/coders/distance_point_bbox_coder.py +++ b/mmdeploy/codebase/mmdet/models/task_modules/coders/distance_point_bbox_coder.py @@ -8,11 +8,7 @@ func_name='mmdet.models.task_modules.coders.distance_point_bbox_coder' '.DistancePointBBoxCoder.decode', backend='default') -def distancepointbboxcoder__decode(ctx, - self, - points, - pred_bboxes, - max_shape=None): +def distancepointbboxcoder__decode(self, points, pred_bboxes, max_shape=None): """Rewrite `mmdet.models.task_modules.coders.distance_point_bbox_coder. 
\ DistancePointBBoxCoder.decode` diff --git a/mmdeploy/codebase/mmdet/models/task_modules/coders/tblr_bbox_coder.py b/mmdeploy/codebase/mmdet/models/task_modules/coders/tblr_bbox_coder.py index c5ca8cd37e..b0f56676cb 100644 --- a/mmdeploy/codebase/mmdet/models/task_modules/coders/tblr_bbox_coder.py +++ b/mmdeploy/codebase/mmdet/models/task_modules/coders/tblr_bbox_coder.py @@ -7,8 +7,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.task_modules.coders.tblr_bbox_coder.tblr2bboxes', backend='default') -def tblr2bboxes(ctx, - priors, +def tblr2bboxes(priors, tblr, normalizer=4.0, normalize_by_wh=True, diff --git a/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/anchor.py b/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/anchor.py index 7550dd03b0..e3432f0f3e 100644 --- a/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/anchor.py +++ b/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/anchor.py @@ -13,6 +13,13 @@ class GridPriorsTRTOp(torch.autograd.Function): def forward(ctx, base_anchors, feat_h, feat_w, stride_h: int, stride_w: int): """Generate grid priors by base anchors.""" + + # torch>=1.13 has runtime error + # when using torch.arange in autograd function + output = getattr(GridPriorsTRTOp, 'output', None) + if output is not None: + return output + device = base_anchors.device dtype = base_anchors.dtype shift_x = torch.arange(0, feat_w, device=device).to(dtype) * stride_w @@ -41,8 +48,8 @@ def symbolic(g, base_anchors, feat_h, feat_w, stride_h: int, stride_w: int): """Map ops to onnx symbolics.""" # zero_h and zero_w is used to provide shape to GridPriorsTRT - feat_h = symbolic_helper._unsqueeze_helper(g, feat_h, [0]) - feat_w = symbolic_helper._unsqueeze_helper(g, feat_w, [0]) + feat_h = g.op('Unsqueeze', feat_h, axes_i=[0]) + feat_w = g.op('Unsqueeze', feat_w, axes_i=[0]) zero_h = g.op( 'ConstantOfShape', feat_h, @@ -70,7 +77,6 @@ def symbolic(g, base_anchors, feat_h, feat_w, stride_h: int, 
'AnchorGenerator.single_level_grid_priors', backend='tensorrt') def anchorgenerator__single_level_grid_priors__trt( - ctx, self, featmap_size: Tuple[int], level_idx: int, @@ -91,10 +97,13 @@ def anchorgenerator__single_level_grid_priors__trt( Returns: torch.Tensor: Anchors in the overall feature maps. """ + ctx = FUNCTION_REWRITER.get_context() feat_h, feat_w = featmap_size + output = ctx.origin_func(self, featmap_size, level_idx, dtype, device).data if isinstance(feat_h, int) and isinstance(feat_w, int): - return ctx.origin_func(self, featmap_size, level_idx, dtype, - device).data + return output base_anchors = self.base_anchors[level_idx].to(device).to(dtype) stride_w, stride_h = self.strides[level_idx] + + GridPriorsTRTOp.output = output return grid_priors_trt(base_anchors, feat_h, feat_w, stride_h, stride_w) diff --git a/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/point_generator.py b/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/point_generator.py index 91e54692ce..bfeb91bdc5 100644 --- a/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/point_generator.py +++ b/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/point_generator.py @@ -10,7 +10,6 @@ '.single_level_grid_priors', backend=Backend.TENSORRT.value) def mlvl_point_generator__single_level_grid_priors__tensorrt( - ctx, self, featmap_size, level_idx, diff --git a/mmdeploy/codebase/mmdet/models/transformer.py b/mmdeploy/codebase/mmdet/models/transformer.py index 7ff62c675a..e89a506699 100644 --- a/mmdeploy/codebase/mmdet/models/transformer.py +++ b/mmdeploy/codebase/mmdet/models/transformer.py @@ -7,7 +7,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.utils.transformer.PatchMerging.forward', backend='tensorrt') -def patch_merging__forward__tensorrt(ctx, self, x, input_size): +def patch_merging__forward__tensorrt(self, x, input_size): """Rewrite forward function of PatchMerging class for TensorRT. 
In original implementation, mmdet applies nn.unfold to accelerate the inference. However, the onnx graph of it can not be parsed correctly by TensorRT. In diff --git a/mmdeploy/codebase/mmdet/structures/bbox/transforms.py b/mmdeploy/codebase/mmdet/structures/bbox/transforms.py index 727ce0f456..aed156f512 100644 --- a/mmdeploy/codebase/mmdet/structures/bbox/transforms.py +++ b/mmdeploy/codebase/mmdet/structures/bbox/transforms.py @@ -8,7 +8,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.structures.bbox.transforms.distance2bbox' # noqa ) -def distance2bbox__default(ctx, points, distance, max_shape=None): +def distance2bbox__default(points, distance, max_shape=None): """Rewrite `mmdet.core.bbox.transforms.distance2bbox` Decode distance prediction to bounding box. diff --git a/mmdeploy/core/optimizers/function_marker.py b/mmdeploy/core/optimizers/function_marker.py index 57ab7ff19c..0417fccfd9 100644 --- a/mmdeploy/core/optimizers/function_marker.py +++ b/mmdeploy/core/optimizers/function_marker.py @@ -62,18 +62,20 @@ def forward(ctx, x, *args) -> torch.Tensor: @FUNCTION_REWRITER.register_rewriter( 'mmdeploy.core.optimizers.function_marker.Mark.symbolic') -def mark_symbolic(rewriter, g, x, *args): +def mark_symbolic(g, x, *args): """Rewrite symbolic of mark op.""" - if cfg_apply_marks(rewriter.cfg): - return rewriter.origin_func(g, x, *args) + ctx = FUNCTION_REWRITER.get_context() + if cfg_apply_marks(ctx.cfg): + return ctx.origin_func(g, x, *args) return x @FUNCTION_REWRITER.register_rewriter( 'mmdeploy.core.optimizers.function_marker.Mark.forward') -def forward_of_mark(rewriter, ctx, x, dtype, shape, func, func_id, type, name, - id, attrs) -> torch.Tensor: +def forward_of_mark(ctx, x, dtype, shape, func, func_id, type, name, id, + attrs) -> torch.Tensor: """Rewrite forward of mark op.""" + rewriter = FUNCTION_REWRITER.get_context() deploy_cfg = rewriter.cfg # save calib data apply_marks = cfg_apply_marks(deploy_cfg) @@ -182,7 +184,7 @@ def impl(ys, 
prefix, level): @FUNCTION_REWRITER.register_rewriter( 'mmdeploy.core.optimizers.function_marker.mark_tensors', ir=IR.TORCHSCRIPT) -def remove_mark__torchscript(ctx, xs: Any, *args, **kwargs): +def remove_mark__torchscript(xs: Any, *args, **kwargs): """Disable all marks for TorchScript backend. As the Node `mark` is not able to be traced, we just return original input diff --git a/mmdeploy/core/rewriters/function_rewriter.py b/mmdeploy/core/rewriters/function_rewriter.py index b623476f36..e6eb207550 100644 --- a/mmdeploy/core/rewriters/function_rewriter.py +++ b/mmdeploy/core/rewriters/function_rewriter.py @@ -4,7 +4,8 @@ from mmdeploy.utils import IR, Backend, get_root_logger from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry, - import_function) + copy_function, get_frame_qual_name, + get_func_qual_name, import_function) def _replace_all_obj(obj: Any, @@ -114,6 +115,7 @@ class FunctionRewriter: def __init__(self): self._registry = RewriterRegistry() + self._func_contexts = {} def register_rewriter( self, @@ -140,6 +142,7 @@ def register_rewriter( def enter(self, cfg: Dict = dict(), env: Dict = dict(), **kwargs): """The implementation of function rewrite.""" + self._func_contexts = {} # Get current records functions_records = self._registry.get_records(env) @@ -181,15 +184,20 @@ def enter(self, cfg: Dict = dict(), env: Dict = dict(), **kwargs): # Create context_caller rewrite_function = record_dict['_object'] + rewrite_function = copy_function(rewrite_function) extra_kwargs = kwargs.copy() extra_kwargs.update(record_dict) - context_caller = ContextCaller( - rewrite_function, origin_func, cfg, - **extra_kwargs).get_wrapped_caller() + context_caller = ContextCaller(rewrite_function, origin_func, + cfg, **extra_kwargs) + + qualname = get_func_qual_name(rewrite_function) + self._func_contexts[qualname] = context_caller + self._func_contexts[function_path] = context_caller # Cache new the function to avoid homonymic bug new_functions.append( - 
dict(func_path=function_path, origin_func=context_caller)) + dict( + func_path=function_path, origin_func=rewrite_function)) for func_dict in new_functions: function_path = func_dict['func_path'] @@ -205,3 +213,21 @@ def exit(self): _set_func(func_path, func) for func_path in self._additional_functions: _del_func(func_path) + + self._func_contexts = {} + + def get_context(self, key: Optional[str] = None) -> ContextCaller: + """Get the context of rewriter. + + Args: + key: key to the context. + + Returns: + ContextCaller: context of function + """ + if key is None: + key = get_frame_qual_name(2) + ctx = self._func_contexts.get(key, None) + if ctx is None: + get_root_logger().warning(f'Can not found context of {key}') + return ctx diff --git a/mmdeploy/core/rewriters/rewriter_utils.py b/mmdeploy/core/rewriters/rewriter_utils.py index 7c5e4e45ee..97cf931609 100644 --- a/mmdeploy/core/rewriters/rewriter_utils.py +++ b/mmdeploy/core/rewriters/rewriter_utils.py @@ -1,5 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +import functools import inspect +import types import warnings from abc import ABCMeta, abstractmethod from functools import wraps @@ -379,3 +381,44 @@ def wrapper(*args, **kwargs): return self.func(self, *args, **kwargs) return wrapper + + +def get_func_qual_name(func: Callable) -> str: + """get function name.""" + assert isinstance(func, Callable), f'{func} is not a Callable object.' + _func_name = None + if hasattr(func, '__qualname__'): + _func_name = f'{func.__module__}.{func.__qualname__}' + elif hasattr(func, '__class__'): + _func_name = func.__class__ + else: + _func_name = str(func) + return _func_name + + +def get_frame_qual_name(top: int = 1) -> str: + """get frame name.""" + frameinfo = inspect.stack()[top] + frame = frameinfo.frame + + g_vars = frame.f_globals + func_name = frameinfo.function + assert func_name in g_vars, \ + f'Can not find function: {func_name} in global.' 
+ func = g_vars[func_name] + module_name = inspect.getmodule(func).__name__ + + return f'{module_name}.{func_name}' + + +def copy_function(f: types.FunctionType): + """Copy the function.""" + g = types.FunctionType( + f.__code__, + f.__globals__, + name=f.__name__, + argdefs=f.__defaults__, + closure=f.__closure__) + g = functools.update_wrapper(g, f) + g.__kwdefaults__ = f.__kwdefaults__ + return g diff --git a/mmdeploy/core/rewriters/symbolic_rewriter.py b/mmdeploy/core/rewriters/symbolic_rewriter.py index dfcc28f761..b4daa16104 100644 --- a/mmdeploy/core/rewriters/symbolic_rewriter.py +++ b/mmdeploy/core/rewriters/symbolic_rewriter.py @@ -7,7 +7,8 @@ from mmdeploy.utils import IR, Backend, get_root_logger from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry, - eval_with_import) + copy_function, eval_with_import, + get_frame_qual_name, get_func_qual_name) class SymbolicRewriter: @@ -34,6 +35,7 @@ class SymbolicRewriter: def __init__(self) -> None: self._registry = RewriterRegistry() + self._func_contexts = {} def register_symbolic(self, func_name: str, @@ -75,6 +77,9 @@ def enter(self, opset: int = 11, **kwargs): """The implementation of symbolic register.""" + # clear context + self._func_contexts = {} + # Get current records symbolic_records = self._registry.get_records(env) @@ -84,19 +89,27 @@ def enter(self, for function_name, record_dict in symbolic_records: symbolic_function = record_dict['_object'] + symbolic_function = copy_function(symbolic_function) arg_descriptors = record_dict['arg_descriptors'] extra_kwargs = kwargs.copy() extra_kwargs.update(record_dict) context_caller = ContextCaller(symbolic_function, None, cfg, **extra_kwargs) + + # register context + qualname = get_func_qual_name(symbolic_function) + self._func_contexts[qualname] = context_caller + self._func_contexts[function_name] = context_caller + if arg_descriptors is not None and len(arg_descriptors) > 0: - context_caller = parse_args(*arg_descriptors)(context_caller) + 
symbolic_function = parse_args(*arg_descriptors)( + symbolic_function) is_pytorch = record_dict['is_pytorch'] if is_pytorch: from torch.onnx import register_custom_op_symbolic register_custom_op_symbolic(f'::{function_name}', - context_caller, opset) + symbolic_function, opset) # Save domain and version self._pytorch_symbolic.append((function_name, '', opset)) @@ -123,7 +136,7 @@ def enter(self, self._extra_symbolic.append((origin_func, origin_symbolic)) # Cache new the function to avoid homonymic bug - new_functions.append((origin_func, context_caller)) + new_functions.append((origin_func, symbolic_function)) for origin_func, new_func in new_functions: origin_symbolic = getattr(origin_func, 'symbolic', None) @@ -132,6 +145,9 @@ def enter(self, def exit(self): """The implementation of symbolic unregister.""" + # clear context + self._func_contexts = {} + # Unregister pytorch op if hasattr(torch.onnx, 'unregister_custom_op_symbolic'): from torch.onnx import unregister_custom_op_symbolic @@ -149,3 +165,19 @@ def exit(self): # Unregister custom op for origin_func, origin_symbolic in self._extra_symbolic: origin_func.symbolic = origin_symbolic + + def get_context(self, key: Optional[str] = None) -> ContextCaller: + """Get the context of rewriter. + + Args: + key: key to the context. + + Returns: + ContextCaller: context of function + """ + if key is None: + key = get_frame_qual_name(2) + ctx = self._func_contexts.get(key, None) + if ctx is None: + get_root_logger().warning(f'Can not found context of {key}') + return ctx diff --git a/mmdeploy/mmcv/cnn/__init__.py b/mmdeploy/mmcv/cnn/__init__.py index 3b777d8b0c..917a4a6df1 100644 --- a/mmdeploy/mmcv/cnn/__init__.py +++ b/mmdeploy/mmcv/cnn/__init__.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from . 
import conv2d_adaptive_padding # noqa: F401,F403 from .transformer import MultiHeadAttentionop __all__ = ['MultiHeadAttentionop'] diff --git a/mmdeploy/mmcv/cnn/conv2d_adaptive_padding.py b/mmdeploy/mmcv/cnn/conv2d_adaptive_padding.py deleted file mode 100644 index d00184c8ee..0000000000 --- a/mmdeploy/mmcv/cnn/conv2d_adaptive_padding.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import math - -import torch -import torch.nn.functional as F - -from mmdeploy.core import FUNCTION_REWRITER -from mmdeploy.utils import Backend, is_dynamic_batch, is_dynamic_shape - - -def compute_padding(input_size, kernel_size, stride, dilation): - """Compute padding.""" - - input_h, input_w = input_size - kernel_h, kernel_w = kernel_size - stride_h, stride_w = stride - dilation_h, dilation_w = dilation - output_h = math.ceil(input_h / stride_h) - output_w = math.ceil(input_w / stride_w) - pad_h = max( - (output_h - 1) * stride_h + (kernel_h - 1) * dilation_h + 1 - input_h, - 0) - pad_w = max( - (output_w - 1) * stride_w + (kernel_w - 1) * dilation_w + 1 - input_w, - 0) - if pad_w > 0 or pad_h > 0: - padded = [ - pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2 - ] - else: - padded = None - return padded - - -class AdaptivePadOp(torch.autograd.Function): - """Dummy adaptive pad op.""" - - @staticmethod - def forward(ctx, x, padded): - if padded is not None: - x = F.pad(x, padded) - return x - - @staticmethod - def symbolic(g, x, padded): - if padded is None: - return g.op('Identity', x) - padded = g.op( - 'Constant', value_t=torch.tensor(padded, dtype=torch.int64)) - constant_value = g.op( - 'Constant', value_t=torch.tensor(0, dtype=torch.int64)) - return g.op( - 'Pad', x, padded, constant_value, mode_s='constant', outputs=1) - - -@FUNCTION_REWRITER.register_rewriter( - func_name='mmcv.cnn.bricks.conv2d_adaptive_padding. 
\ - Conv2dAdaptivePadding.forward', - backend=Backend.TENSORRT.value) -def conv2d_adaptive_padding__forward__tensorrt(ctx, self, x): - """Rewrite `forward` of Conv2dAdaptivePadding used in EfficientNet for - TensorRT backend. Main changes of this rewritten function is to separate - the computation of padding and encapsulate it into another - `torch.autograd.Function` so that the adaptive padding could be parsed as - `Pad` ops in ONNX with the padding information computed in advance (Only - for static shape configuration). - - Args: - x (Tensor): Input tensor of Conv2dAdaptivePadding ops - Returns: - Tensor: forward result of 2D convolution after padding - """ - - deploy_cfg = ctx.cfg - is_dynamic_flag = is_dynamic_shape(deploy_cfg) - if (not is_dynamic_flag) or is_dynamic_batch(deploy_cfg): - padded = compute_padding(x.shape[2:], self.weight.shape[2:], - self.stride, self.dilation) - if padded is not None: - padded = [int(_) for _ in padded] - x = AdaptivePadOp.apply(x, padded) - return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, - self.dilation, self.groups) - else: - x = ctx.origin_func(x) - return x diff --git a/mmdeploy/mmcv/cnn/transformer.py b/mmdeploy/mmcv/cnn/transformer.py index 58f79657c2..6069f5a43c 100644 --- a/mmdeploy/mmcv/cnn/transformer.py +++ b/mmdeploy/mmcv/cnn/transformer.py @@ -57,8 +57,7 @@ def symbolic(g, q: torch._C.Value, k: torch._C.Value, v: torch._C.Value, @FUNCTION_REWRITER.register_rewriter( func_name='mmcv.cnn.bricks.transformer.MultiheadAttention.forward', backend=Backend.NCNN.value) -def multiheadattention__forward__ncnn(ctx, - self, +def multiheadattention__forward__ncnn(self, query, key=None, value=None, diff --git a/mmdeploy/mmcv/ops/deform_conv.py b/mmdeploy/mmcv/ops/deform_conv.py index 3e2a436f48..fbbc300b8e 100644 --- a/mmdeploy/mmcv/ops/deform_conv.py +++ b/mmdeploy/mmcv/ops/deform_conv.py @@ -4,8 +4,7 @@ @SYMBOLIC_REWRITER.register_symbolic( 'mmcv.ops.deform_conv.DeformConv2dFunction') -def 
deform_conv__default(ctx, - g, +def deform_conv__default(g, input, offset, weight, @@ -31,8 +30,7 @@ def deform_conv__default(ctx, @SYMBOLIC_REWRITER.register_symbolic( 'mmcv.ops.deform_conv.DeformConv2dFunction', backend='openvino') -def deform_conv_openvino(ctx, - g, +def deform_conv_openvino(g, input, offset, weight, diff --git a/mmdeploy/mmcv/ops/modulated_deform_conv.py b/mmdeploy/mmcv/ops/modulated_deform_conv.py index df3c338a81..64fd9fdd7e 100644 --- a/mmdeploy/mmcv/ops/modulated_deform_conv.py +++ b/mmdeploy/mmcv/ops/modulated_deform_conv.py @@ -4,9 +4,8 @@ @SYMBOLIC_REWRITER.register_symbolic( 'mmcv.ops.modulated_deform_conv.ModulatedDeformConv2dFunction') -def modulated_deform_conv_default(ctx, g, input, offset, mask, weight, bias, - stride, padding, dilation, groups, - deform_groups): +def modulated_deform_conv_default(g, input, offset, mask, weight, bias, stride, + padding, dilation, groups, deform_groups): """Rewrite mdcn symbolic function for all backend.""" input_tensors = [input, offset, mask, weight] if bias is not None: diff --git a/mmdeploy/mmcv/ops/nms.py b/mmdeploy/mmcv/ops/nms.py index 1ab303bbaa..89ca2d604c 100644 --- a/mmdeploy/mmcv/ops/nms.py +++ b/mmdeploy/mmcv/ops/nms.py @@ -1,10 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +from mmcv.ops import nms from torch import Tensor from torch.onnx import symbolic_helper as sym_help -from mmdeploy.core import SYMBOLIC_REWRITER - class ONNXNMSop(torch.autograd.Function): """Create onnx::NonMaxSuppression op. @@ -34,7 +33,6 @@ def forward(ctx, boxes: Tensor, scores: Tensor, (num_selected_indices, 3) with each row of [batch_index, class_index, box_index]. """ - from mmcv.ops import nms batch_size, num_class, _ = scores.shape score_threshold = float(score_threshold) @@ -78,57 +76,23 @@ def symbolic(g, boxes: Tensor, scores: Tensor, Returns: NonMaxSuppression op for onnx. 
""" - return g.op( - 'NonMaxSuppression', - boxes, - scores, - max_output_boxes_per_class, - iou_threshold, - score_threshold, - outputs=1) - - -@SYMBOLIC_REWRITER.register_symbolic( - 'mmdeploy.mmcv.ops.ONNXNMSop', backend='default') -def nms_dynamic(ctx, g, boxes: Tensor, scores: Tensor, - max_output_boxes_per_class: int, iou_threshold: float, - score_threshold: float): - """Rewrite symbolic function for default backend. - - Support max_output_boxes_per_class, iou_threshold, score_threshold of - constant Tensor, which is aligned with ONNX's nms op. - - Args: - ctx (ContextCaller): The context with additional information. - g (Graph): The traced onnx graph. - boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4]. - scores (Tensor): The detection scores of shape - [N, num_boxes, num_classes]. - max_output_boxes_per_class (int): Maximum number of output - boxes per class of nms. - iou_threshold (float): IOU threshold of nms. - score_threshold (float): score threshold of nms. - - Returns: - NonMaxSuppression op for onnx. 
- """ - - if not sym_help._is_value(max_output_boxes_per_class): - max_output_boxes_per_class = g.op( - 'Constant', - value_t=torch.tensor(max_output_boxes_per_class, dtype=torch.long)) - - if not sym_help._is_value(iou_threshold): - iou_threshold = g.op( - 'Constant', - value_t=torch.tensor([iou_threshold], dtype=torch.float)) - - if not sym_help._is_value(score_threshold): - score_threshold = g.op( - 'Constant', - value_t=torch.tensor([score_threshold], dtype=torch.float)) - return g.op('NonMaxSuppression', boxes, scores, max_output_boxes_per_class, - iou_threshold, score_threshold) + if not sym_help._is_value(max_output_boxes_per_class): + max_output_boxes_per_class = g.op( + 'Constant', + value_t=torch.tensor( + max_output_boxes_per_class, dtype=torch.long)) + + if not sym_help._is_value(iou_threshold): + iou_threshold = g.op( + 'Constant', + value_t=torch.tensor([iou_threshold], dtype=torch.float)) + + if not sym_help._is_value(score_threshold): + score_threshold = g.op( + 'Constant', + value_t=torch.tensor([score_threshold], dtype=torch.float)) + return g.op('NonMaxSuppression', boxes, scores, + max_output_boxes_per_class, iou_threshold, score_threshold) class TRTBatchedNMSop(torch.autograd.Function): diff --git a/mmdeploy/mmcv/ops/point_sample.py b/mmdeploy/mmcv/ops/point_sample.py index 7f2e43ecfa..8051b708dd 100644 --- a/mmdeploy/mmcv/ops/point_sample.py +++ b/mmdeploy/mmcv/ops/point_sample.py @@ -6,7 +6,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmcv.ops.point_sample', backend='default') -def point_sample__default(ctx, input, points, align_corners=False, **kwargs): +def point_sample__default(input, points, align_corners=False, **kwargs): """A wrapper around :func:`grid_sample` to support 3D point_coords tensors Unlike :func:`torch.nn.functional.grid_sample` it assumes point_coords to lie inside ``[0, 1] x [0, 1]`` square. 
@@ -37,7 +37,7 @@ def point_sample__default(ctx, input, points, align_corners=False, **kwargs): @FUNCTION_REWRITER.register_rewriter( func_name='mmcv.ops.SimpleRoIAlign.forward') -def simple_roialign__forward(ctx, self, features, rois): +def simple_roialign__forward(self, features, rois): """Rewrite `forward` of SimpleRoIAlign. Args: diff --git a/mmdeploy/mmcv/ops/roi_align.py b/mmdeploy/mmcv/ops/roi_align.py index 7e99ef414f..6ee901a047 100644 --- a/mmdeploy/mmcv/ops/roi_align.py +++ b/mmdeploy/mmcv/ops/roi_align.py @@ -13,9 +13,9 @@ # visible in mmcv. @SYMBOLIC_REWRITER.register_symbolic( 'mmcv.ops.roi_align.__self__', backend='default') -def roi_align_default(ctx, g, input: Tensor, rois: Tensor, - output_size: List[int], spatial_scale: float, - sampling_ratio: int, pool_mode: str, aligned: bool): +def roi_align_default(g, input: Tensor, rois: Tensor, output_size: List[int], + spatial_scale: float, sampling_ratio: int, + pool_mode: str, aligned: bool): """Rewrite symbolic function for default backend. Replace onnx::RoiAlign with mmcv::MMCVRoiAlign for PPLNN. For ONNXRuntime, @@ -41,6 +41,7 @@ def roi_align_default(ctx, g, input: Tensor, rois: Tensor, Returns: MMCVRoiAlign op for onnx. 
""" + ctx = SYMBOLIC_REWRITER.get_context() backend = get_backend(ctx.cfg) if backend == Backend.PPLNN or backend == Backend.TENSORRT: domain = 'mmcv' @@ -56,17 +57,17 @@ def roi_align_default(ctx, g, input: Tensor, rois: Tensor, aligned_i=aligned) else: from torch.onnx.symbolic_opset9 import _cast_Long - from torch.onnx.symbolic_opset11 import add, select, squeeze + from torch.onnx.symbolic_opset11 import add, select batch_indices = _cast_Long( g, - squeeze( - g, + g.op( + 'Squeeze', select( g, rois, 1, g.op( 'Constant', - value_t=torch.tensor([0], dtype=torch.long))), 1), - False) + value_t=torch.tensor([0], dtype=torch.long))), + axes_i=[1]), False) rois = select( g, rois, 1, g.op( diff --git a/mmdeploy/mmcv/ops/roi_align_rotated.py b/mmdeploy/mmcv/ops/roi_align_rotated.py index f7707071da..90c2e0414d 100644 --- a/mmdeploy/mmcv/ops/roi_align_rotated.py +++ b/mmdeploy/mmcv/ops/roi_align_rotated.py @@ -11,7 +11,7 @@ # is not visible in mmcv. @SYMBOLIC_REWRITER.register_symbolic( 'mmcv.ops.roi_align_rotated.__self__', backend='default') -def roi_align_rotated_default(ctx, g, input: Tensor, rois: Tensor, +def roi_align_rotated_default(g, input: Tensor, rois: Tensor, output_size: List[int], spatial_scale: float, sampling_ratio: int, aligned: bool, clockwise: bool): diff --git a/mmdeploy/mmcv/ops/transformer.py b/mmdeploy/mmcv/ops/transformer.py index 53f7f550b9..bf020cce24 100644 --- a/mmdeploy/mmcv/ops/transformer.py +++ b/mmdeploy/mmcv/ops/transformer.py @@ -6,7 +6,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmcv.cnn.bricks.transformer.PatchEmbed.forward', backend=Backend.NCNN.value) -def patch_embed__forward__ncnn(ctx, self, x): +def patch_embed__forward__ncnn(self, x): """Rewrite `forward` of PatchEmbed for ncnn backend. 
Args: diff --git a/mmdeploy/pytorch/functions/adaptive_pool.py b/mmdeploy/pytorch/functions/adaptive_pool.py index fb09cd82e4..14a185ed87 100644 --- a/mmdeploy/pytorch/functions/adaptive_pool.py +++ b/mmdeploy/pytorch/functions/adaptive_pool.py @@ -9,8 +9,9 @@ @FUNCTION_REWRITER.register_rewriter( func_name='torch.nn.functional.adaptive_avg_pool2d') -def adaptive_avg_pool2d__default(ctx, input, output_size): +def adaptive_avg_pool2d__default(input, output_size): """Rewrite `adaptive_avg_pool2d` for default backend.""" + ctx = FUNCTION_REWRITER.get_context() output_size = _pair(output_size) if int(output_size[0]) == int(output_size[1]) == 1: out = ctx.origin_func(input, output_size) @@ -39,6 +40,7 @@ def adaptive_avg_pool2d__default(ctx, input, output_size): @FUNCTION_REWRITER.register_rewriter( func_name='torch.nn.functional.adaptive_avg_pool2d', backend=Backend.TORCHSCRIPT.value) -def adaptive_avg_pool2d__ncnn(ctx, input, output_size): +def adaptive_avg_pool2d__ncnn(input, output_size): """Rewrite `adaptive_avg_pool2d` for ncnn and torchscript backend.""" + ctx = FUNCTION_REWRITER.get_context() return ctx.origin_func(input, output_size) diff --git a/mmdeploy/pytorch/functions/atan2.py b/mmdeploy/pytorch/functions/atan2.py index a09986a8f2..90ce5d63d8 100644 --- a/mmdeploy/pytorch/functions/atan2.py +++ b/mmdeploy/pytorch/functions/atan2.py @@ -7,7 +7,6 @@ @FUNCTION_REWRITER.register_rewriter( func_name='torch.atan2', backend='default') def atan2__default( - ctx, input1: torch.Tensor, input2: torch.Tensor, ): diff --git a/mmdeploy/pytorch/functions/chunk.py b/mmdeploy/pytorch/functions/chunk.py index 98ad1b2eff..29677b2ead 100644 --- a/mmdeploy/pytorch/functions/chunk.py +++ b/mmdeploy/pytorch/functions/chunk.py @@ -7,7 +7,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.chunk', backend='ncnn') -def chunk__ncnn(ctx, self, num_chunks: int, dim: int = 0) -> torch.Tensor: +def chunk__ncnn(self, num_chunks: int, dim: int = 0) -> torch.Tensor:
"""Rewrite `chunk` for NCNN backend. Chunk in ncnn are not supported, so it should be rewritten. @@ -36,10 +36,7 @@ def chunk__ncnn(ctx, self, num_chunks: int, dim: int = 0) -> torch.Tensor: @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.chunk', ir=IR.TORCHSCRIPT) -def chunk__torchscript(ctx, - self, - num_chunks: int, - dim: int = 0) -> torch.Tensor: +def chunk__torchscript(self, num_chunks: int, dim: int = 0) -> torch.Tensor: """Rewrite `chunk` for Torchscript. Replace chunk op with split op diff --git a/mmdeploy/pytorch/functions/clip.py b/mmdeploy/pytorch/functions/clip.py index 88a9b64894..c550358f49 100644 --- a/mmdeploy/pytorch/functions/clip.py +++ b/mmdeploy/pytorch/functions/clip.py @@ -13,11 +13,12 @@ func_name='torch.Tensor.clamp', backend=Backend.COREML.value) @FUNCTION_REWRITER.register_rewriter( func_name='torch.clamp', backend=Backend.COREML.value) -def clip__coreml(ctx, input, min=None, max=None, **kwargs) -> torch.Tensor: +def clip__coreml(input, min=None, max=None, **kwargs) -> torch.Tensor: """Rewrite `clip` for coreml backend. Cast data type. """ + ctx = FUNCTION_REWRITER.get_context() if min is not None and not isinstance(min, torch.Tensor): min = input.new_tensor(min) diff --git a/mmdeploy/pytorch/functions/expand.py b/mmdeploy/pytorch/functions/expand.py index 0ae90f8a4a..c2a1aba70e 100644 --- a/mmdeploy/pytorch/functions/expand.py +++ b/mmdeploy/pytorch/functions/expand.py @@ -6,11 +6,12 @@ @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.expand', backend='ncnn') -def expand__ncnn(ctx, self, *sizes) -> torch.Tensor: +def expand__ncnn(self, *sizes) -> torch.Tensor: """Rewrite `expand` for NCNN backend. 
Do not expand on batch dim for tensor with ndim >= 3 """ + ctx = FUNCTION_REWRITER.get_context() if self.ndim < 3 or sizes[0] not in [1, -1]: return ctx.origin_func(*sizes) return self diff --git a/mmdeploy/pytorch/functions/flatten.py b/mmdeploy/pytorch/functions/flatten.py index d8d40dd54c..7270f32fd4 100644 --- a/mmdeploy/pytorch/functions/flatten.py +++ b/mmdeploy/pytorch/functions/flatten.py @@ -13,7 +13,7 @@ func_name='torch.Tensor.flatten', backend=Backend.COREML.value) @FUNCTION_REWRITER.register_rewriter( func_name='torch.flatten', backend=Backend.COREML.value) -def flatten__coreml(ctx, input, start_dim=0, end_dim=-1) -> torch.Tensor: +def flatten__coreml(input, start_dim=0, end_dim=-1) -> torch.Tensor: """Rewrite `flatten` for coreml backend. Use reshape instead of flatten diff --git a/mmdeploy/pytorch/functions/getattribute.py b/mmdeploy/pytorch/functions/getattribute.py index 8447aca8b7..74e9bfa0b8 100644 --- a/mmdeploy/pytorch/functions/getattribute.py +++ b/mmdeploy/pytorch/functions/getattribute.py @@ -6,13 +6,14 @@ @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.__getattribute__', backend='ncnn') -def tensor__getattribute__ncnn(ctx, self: torch.Tensor, name: str): +def tensor__getattribute__ncnn(self: torch.Tensor, name: str): """Rewrite `__getattribute__` of `torch.Tensor` for ncnn backend. Shape node is not supported by ncnn. This function transform dynamic shape to constant shape. 
""" + ctx = FUNCTION_REWRITER.get_context() ret = ctx.origin_func(self, name) if name == 'shape': ret = torch.Size([int(s) for s in ret]) diff --git a/mmdeploy/pytorch/functions/group_norm.py b/mmdeploy/pytorch/functions/group_norm.py index 393fd720d4..25fe2b98ad 100644 --- a/mmdeploy/pytorch/functions/group_norm.py +++ b/mmdeploy/pytorch/functions/group_norm.py @@ -9,7 +9,6 @@ @FUNCTION_REWRITER.register_rewriter( func_name='torch.nn.functional.group_norm', backend='ncnn') def group_norm__ncnn( - ctx, input: torch.Tensor, num_groups: int, weight: Union[torch.Tensor, torch.NoneType] = None, diff --git a/mmdeploy/pytorch/functions/interpolate.py b/mmdeploy/pytorch/functions/interpolate.py index a335792f0c..39424b8a39 100644 --- a/mmdeploy/pytorch/functions/interpolate.py +++ b/mmdeploy/pytorch/functions/interpolate.py @@ -10,8 +10,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='torch.nn.functional.interpolate', backend='ncnn') -def interpolate__ncnn(ctx, - input: torch.Tensor, +def interpolate__ncnn(input: torch.Tensor, size: Optional[Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]] = None, scale_factor: Optional[Union[float, @@ -24,6 +23,7 @@ def interpolate__ncnn(ctx, ncnn require `size` should be constant in ONNX Node. We use `scale_factor` instead of `size` to avoid dynamic size. """ + ctx = FUNCTION_REWRITER.get_context() input_size = input.shape if scale_factor is None: @@ -42,8 +42,7 @@ def interpolate__ncnn(ctx, @FUNCTION_REWRITER.register_rewriter( func_name='torch.nn.functional.interpolate', backend='rknn') -def interpolate__rknn(ctx, - input: torch.Tensor, +def interpolate__rknn(input: torch.Tensor, size: Optional[Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]] = None, scale_factor: Optional[Union[float, @@ -56,6 +55,7 @@ def interpolate__rknn(ctx, rknn require `size` should be constant in ONNX Node. We use `scale_factor` instead of `size` to avoid dynamic size. 
""" + ctx = FUNCTION_REWRITER.get_context() input_size = input.shape if scale_factor is None: scale_factor = [(s_out / s_in) @@ -77,7 +77,6 @@ def interpolate__rknn(ctx, is_pytorch=True, backend=Backend.TENSORRT.value) def interpolate__tensorrt( - ctx, input: torch.Tensor, size: Optional[Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]] = None, @@ -87,6 +86,7 @@ def interpolate__tensorrt( recompute_scale_factor: Optional[bool] = None, ): """Register default symbolic function for `interpolate`.""" + ctx = FUNCTION_REWRITER.get_context() class BicubicInterpolate(Function): diff --git a/mmdeploy/pytorch/functions/linear.py b/mmdeploy/pytorch/functions/linear.py index 7cfb4735ac..616fef732a 100644 --- a/mmdeploy/pytorch/functions/linear.py +++ b/mmdeploy/pytorch/functions/linear.py @@ -30,7 +30,6 @@ def symbolic(g, input, weight, bias=None): @FUNCTION_REWRITER.register_rewriter( func_name='torch.nn.functional.linear', backend='ncnn') def linear__ncnn( - ctx, input: torch.Tensor, weight: torch.Tensor, bias: Optional[Union[torch.Tensor, torch.NoneType]] = None, @@ -41,6 +40,7 @@ def linear__ncnn( add extra reshape and transpose to support linear operation of different input shape. """ + ctx = FUNCTION_REWRITER.get_context() origin_func = ctx.origin_func dim = input.dim() diff --git a/mmdeploy/pytorch/functions/masked_fill.py b/mmdeploy/pytorch/functions/masked_fill.py index 5e4f67b45e..bd8cd7b6c2 100644 --- a/mmdeploy/pytorch/functions/masked_fill.py +++ b/mmdeploy/pytorch/functions/masked_fill.py @@ -13,13 +13,14 @@ @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.masked_fill', backend=Backend.ONNXRUNTIME.value) def masked_fill__onnxruntime( - ctx, input, mask: torch.Tensor, value: Union[torch.Tensor, - Number]) -> torch.Tensor: + input, mask: torch.Tensor, value: Union[torch.Tensor, + Number]) -> torch.Tensor: """Rewrite `masked_fill` for onnxruntime backend. 
SATRN model as example, when value is set to `float('-inf')`, the results of ORT inferencing turns out to be NAN. """ + ctx = FUNCTION_REWRITER.get_context() if value == float('-inf'): value = -1e34 # hard coding number return ctx.origin_func(input, mask, value) diff --git a/mmdeploy/pytorch/functions/mod.py b/mmdeploy/pytorch/functions/mod.py index e6bb1cb51c..bd1bd77d12 100644 --- a/mmdeploy/pytorch/functions/mod.py +++ b/mmdeploy/pytorch/functions/mod.py @@ -11,10 +11,11 @@ # TODO add version control when MOD is supported by TensorRT @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.__mod__', backend=Backend.TENSORRT.value) -def mod__tensorrt(ctx, input: torch.Tensor, other: Union[torch.Tensor, - torch.NumberType], - *args, **kwargs) -> torch.Tensor: +def mod__tensorrt(input: torch.Tensor, other: Union[torch.Tensor, + torch.NumberType], *args, + **kwargs) -> torch.Tensor: """Rewrite `mod` when exporting model to ONNX for TensorRT backend.""" + ctx = FUNCTION_REWRITER.get_context() if version.parse(torch.__version__) > version.parse('1.10.0'): return input - (input // other) * other return ctx.origin_func(input, other, *args, **kwargs) diff --git a/mmdeploy/pytorch/functions/multi_head_attention_forward.py b/mmdeploy/pytorch/functions/multi_head_attention_forward.py index 8d165649c5..fadfc3e91d 100644 --- a/mmdeploy/pytorch/functions/multi_head_attention_forward.py +++ b/mmdeploy/pytorch/functions/multi_head_attention_forward.py @@ -46,7 +46,6 @@ def symbolic(g, q, k, v, mask): func_name='torch.nn.functional._scaled_dot_product_attention', backend=Backend.TENSORRT.value) def _scaled_dot_product_attention__tensorrt( - ctx, q: Tensor, k: Tensor, v: Tensor, diff --git a/mmdeploy/pytorch/functions/normalize.py b/mmdeploy/pytorch/functions/normalize.py index a676439cdd..b0ae4ccfe2 100644 --- a/mmdeploy/pytorch/functions/normalize.py +++ b/mmdeploy/pytorch/functions/normalize.py @@ -7,8 +7,7 @@ @FUNCTION_REWRITER.register_rewriter( 
func_name='torch.nn.functional.normalize', backend='ncnn') -def normalize__ncnn(ctx, - input: torch.Tensor, +def normalize__ncnn(input: torch.Tensor, p: int = 2, dim: int = 1, eps: float = 1e-12, @@ -18,6 +17,7 @@ def normalize__ncnn(ctx, Make sure L2 norm on channel dim and be exported to ncnn correctly. """ + ctx = FUNCTION_REWRITER.get_context() if dim < 0: dim += input.ndim assert dim != 0, 'Should not normalize on batch index' diff --git a/mmdeploy/pytorch/functions/pad.py b/mmdeploy/pytorch/functions/pad.py index 7f24785e8b..82274d5c06 100644 --- a/mmdeploy/pytorch/functions/pad.py +++ b/mmdeploy/pytorch/functions/pad.py @@ -11,7 +11,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='torch.onnx.symbolic_opset11._prepare_onnx_paddings', backend='tensorrt') -def _prepare_onnx_paddings__tensorrt(ctx, g, input, pad): +def _prepare_onnx_paddings__tensorrt(g, input, pad): """Rewrite `_prepare_onnx_paddings` for TensorRT backend. For codes like `x = torch.nn.ZeroPad2d((0, a, 0, b))(x)`, where a and b are @@ -26,6 +26,7 @@ def _prepare_onnx_paddings__tensorrt(ctx, g, input, pad): ..., dim_m_begin, dim_m_end, where m is in range [0, n]. """ + ctx = FUNCTION_REWRITER.get_context() torch_version = version_parse(torch.__version__) if torch_version.major == 1 and torch_version.minor < 10: return ctx.origin_func(g, input, pad) diff --git a/mmdeploy/pytorch/functions/repeat.py b/mmdeploy/pytorch/functions/repeat.py index fa528c33f5..edb6efc3a5 100644 --- a/mmdeploy/pytorch/functions/repeat.py +++ b/mmdeploy/pytorch/functions/repeat.py @@ -8,13 +8,14 @@ @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.repeat', backend='tensorrt') -def tensor__repeat__tensorrt(ctx, input: torch.Tensor, - *size: Union[torch.Size, Sequence[int]]): +def tensor__repeat__tensorrt(input: torch.Tensor, *size: Union[torch.Size, + Sequence[int]]): """Rewrite `repeat` for TensorRT backend. Some layers in TensorRT can not be applied on batch axis. 
add extra axis before operation and remove it afterward. """ + ctx = FUNCTION_REWRITER.get_context() origin_func = ctx.origin_func if input.dim() == 1 and len(size) == 1: diff --git a/mmdeploy/pytorch/functions/size.py b/mmdeploy/pytorch/functions/size.py index 30ead981ab..8325f115cd 100644 --- a/mmdeploy/pytorch/functions/size.py +++ b/mmdeploy/pytorch/functions/size.py @@ -6,12 +6,13 @@ @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.size', backend='ncnn') -def tensor__size__ncnn(ctx, self, *args): +def tensor__size__ncnn(self, *args): """Rewrite `size` for ncnn backend. ONNX Shape node is not supported in ncnn. This function return integer instead of Torch.Size to avoid ONNX Shape node. """ + ctx = FUNCTION_REWRITER.get_context() ret = ctx.origin_func(self, *args) if isinstance(ret, torch.Tensor): @@ -26,11 +27,12 @@ def tensor__size__ncnn(ctx, self, *args): @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.size', backend='ascend') -def tensor__size__ascend(ctx, self, *args): +def tensor__size__ascend(self, *args): """Rewrite `size` for ascens backend. Support negative index. """ + ctx = FUNCTION_REWRITER.get_context() if len(args) != 0: index = args[0] diff --git a/mmdeploy/pytorch/functions/tensor_getitem.py b/mmdeploy/pytorch/functions/tensor_getitem.py index 7454a5a6d7..17187eeb94 100644 --- a/mmdeploy/pytorch/functions/tensor_getitem.py +++ b/mmdeploy/pytorch/functions/tensor_getitem.py @@ -8,11 +8,12 @@ @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.__getitem__', backend='ascend') -def tensor__getitem__ascend(ctx, self, key) -> torch.Tensor: +def tensor__getitem__ascend(self, key) -> torch.Tensor: """Rewrite `getitem` for ascend backend. 
Ascend does not support negative select """ + ctx = FUNCTION_REWRITER.get_context() if not isinstance(key, (tuple, list)): if isinstance(key, int) and key < 0: key = self.dim() + key diff --git a/mmdeploy/pytorch/functions/tensor_setitem.py b/mmdeploy/pytorch/functions/tensor_setitem.py index 6795bc2415..4860bbe143 100644 --- a/mmdeploy/pytorch/functions/tensor_setitem.py +++ b/mmdeploy/pytorch/functions/tensor_setitem.py @@ -8,8 +8,9 @@ @FUNCTION_REWRITER.register_rewriter(func_name='torch.Tensor.__setitem__') -def tensor__setitem__default(ctx, self, key, value): +def tensor__setitem__default(self, key, value): """Rewrite `setitem` to ease the index put.""" + ctx = FUNCTION_REWRITER.get_context() # only support torch>=1.9.0 if parse(torch.__version__) < parse('1.9.0'): @@ -76,5 +77,5 @@ def tensor__setitem__default(ctx, self, key, value): if parse(torch.__version__) >= parse('1.12.0'): @SYMBOLIC_REWRITER.register_symbolic('copy', is_pytorch=True) - def copy__default(ctx, g, x, y, non_blocking): + def copy__default(g, x, y, non_blocking): return x diff --git a/mmdeploy/pytorch/functions/topk.py b/mmdeploy/pytorch/functions/topk.py index 82569250b5..38dac19786 100644 --- a/mmdeploy/pytorch/functions/topk.py +++ b/mmdeploy/pytorch/functions/topk.py @@ -10,8 +10,7 @@ @FUNCTION_REWRITER.register_rewriter(func_name='torch.topk', backend='default') @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.topk', backend='default') -def topk__dynamic(ctx, - input: torch.Tensor, +def topk__dynamic(input: torch.Tensor, k: int, dim: Optional[int] = None, largest: bool = True, @@ -20,6 +19,7 @@ def topk__dynamic(ctx, Cast k to tensor and makesure k is smaller than input.shape[dim]. 
""" + ctx = FUNCTION_REWRITER.get_context() if dim is None: dim = int(input.ndim - 1) @@ -37,8 +37,7 @@ def topk__dynamic(ctx, func_name='torch.topk', backend='tensorrt') @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.topk', backend='tensorrt') -def topk__tensorrt(ctx, - input: torch.Tensor, +def topk__tensorrt(input: torch.Tensor, k: int, dim: Optional[int] = None, largest: bool = True, @@ -48,6 +47,7 @@ def topk__tensorrt(ctx, TensorRT does not support topk with dynamic k. This function cast k to constant integer. """ + ctx = FUNCTION_REWRITER.get_context() # https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#topKsetup from mmdeploy.utils.constants import TENSORRT_MAX_TOPK diff --git a/mmdeploy/pytorch/functions/triu.py b/mmdeploy/pytorch/functions/triu.py index 025b2029ff..e7e8e501ee 100644 --- a/mmdeploy/pytorch/functions/triu.py +++ b/mmdeploy/pytorch/functions/triu.py @@ -5,8 +5,7 @@ @FUNCTION_REWRITER.register_rewriter(func_name='torch.triu') -def triu__default(ctx, - input: torch.Tensor, +def triu__default(input: torch.Tensor, diagonal: int = 0, *args, **kwargs) -> torch.Tensor: diff --git a/mmdeploy/pytorch/symbolics/adaptive_pool.py b/mmdeploy/pytorch/symbolics/adaptive_pool.py index d27049576b..a3461313a1 100644 --- a/mmdeploy/pytorch/symbolics/adaptive_pool.py +++ b/mmdeploy/pytorch/symbolics/adaptive_pool.py @@ -5,7 +5,7 @@ @SYMBOLIC_REWRITER.register_symbolic( 'adaptive_avg_pool2d', is_pytorch=True, backend='ncnn') -def adaptive_avg_pool2d__ncnn(ctx, g, x, output_size): +def adaptive_avg_pool2d__ncnn(g, x, output_size): """Register ncnn symbolic function for `adaptive_avg_pool2d`. Align symbolic of adaptive_avg_pool2d in ncnn. diff --git a/mmdeploy/pytorch/symbolics/gelu.py b/mmdeploy/pytorch/symbolics/gelu.py index 039e5a1147..3d9131181e 100644 --- a/mmdeploy/pytorch/symbolics/gelu.py +++ b/mmdeploy/pytorch/symbolics/gelu.py @@ -1,11 +1,18 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from torch.onnx import symbolic_helper from mmdeploy.core import SYMBOLIC_REWRITER from mmdeploy.utils import Backend +@symbolic_helper.parse_args('v') +def gelu__ncnn_pt111(g, self): + """gelu for torch<=1.12.""" + return g.op('mmdeploy::Gelu', self) + + @SYMBOLIC_REWRITER.register_symbolic( - 'gelu', is_pytorch=True, arg_descriptors=['v'], backend=Backend.NCNN.value) -def gelu__ncnn(ctx, g, self): + 'gelu', is_pytorch=True, backend=Backend.NCNN.value) +def gelu__ncnn(g, self, approximate: str = 'none'): """Support export GELU with ncnn backend.""" - return g.op('mmdeploy::Gelu', self) + return gelu__ncnn_pt111(g, self) diff --git a/mmdeploy/pytorch/symbolics/grid_sampler.py b/mmdeploy/pytorch/symbolics/grid_sampler.py index b6fdbf5c7a..0e3e105106 100644 --- a/mmdeploy/pytorch/symbolics/grid_sampler.py +++ b/mmdeploy/pytorch/symbolics/grid_sampler.py @@ -50,11 +50,12 @@ def grid_sampler_ppl(g, @SYMBOLIC_REWRITER.register_symbolic('grid_sampler', is_pytorch=True) -def grid_sampler__default(ctx, *args): +def grid_sampler__default(*args): """Register default symbolic function for `grid_sampler`. Add support to grid_sample to ONNX. 
""" + ctx = SYMBOLIC_REWRITER.get_context() backend = get_backend(ctx.cfg) if backend == Backend.PPLNN: return grid_sampler_ppl(*args) diff --git a/mmdeploy/pytorch/symbolics/hardsigmoid.py b/mmdeploy/pytorch/symbolics/hardsigmoid.py index a4d14173ed..27561685ed 100644 --- a/mmdeploy/pytorch/symbolics/hardsigmoid.py +++ b/mmdeploy/pytorch/symbolics/hardsigmoid.py @@ -6,7 +6,7 @@ @SYMBOLIC_REWRITER.register_symbolic( 'hardsigmoid', is_pytorch=True, arg_descriptors=['v']) -def hardsigmoid__default(ctx, g, self): +def hardsigmoid__default(g, self): """Support export hardsigmoid This rewrite enable export hardsigmoid in torch<=1.8.2.""" return g.op('HardSigmoid', self, alpha_f=1 / 6) diff --git a/mmdeploy/pytorch/symbolics/instance_norm.py b/mmdeploy/pytorch/symbolics/instance_norm.py index c04e42528a..06d287574b 100644 --- a/mmdeploy/pytorch/symbolics/instance_norm.py +++ b/mmdeploy/pytorch/symbolics/instance_norm.py @@ -64,7 +64,7 @@ def instance_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled): @SYMBOLIC_REWRITER.register_symbolic( 'group_norm', backend='tensorrt', is_pytorch=True) -def instance_norm__tensorrt(ctx, *args): +def instance_norm__tensorrt(*args): """Register symbolic function for TensorRT backend. 
Notes: diff --git a/mmdeploy/pytorch/symbolics/layer_norm.py b/mmdeploy/pytorch/symbolics/layer_norm.py index 94ea0169a4..854ef5fd7f 100644 --- a/mmdeploy/pytorch/symbolics/layer_norm.py +++ b/mmdeploy/pytorch/symbolics/layer_norm.py @@ -12,7 +12,7 @@ 'layer_norm', is_pytorch=True, arg_descriptors=['v', 'is', 'v', 'v', 'f', 'i']) -def layer_norm__default(ctx, g, input, normalized_shape, weight, bias, eps, +def layer_norm__default(g, input, normalized_shape, weight, bias, eps, cudnn_enable): """Symbolic function for `layer_norm` @@ -62,7 +62,7 @@ def _layer_norm_ncnn(g, input, normalized_shape, weight, bias, eps, @SYMBOLIC_REWRITER.register_symbolic( 'layer_norm', is_pytorch=True, backend=Backend.NCNN.value) -def layer_norm__ncnn(ctx, *args): +def layer_norm__ncnn(*args): """Register default symbolic function for `layer_norm`. Add support to layer_norm to ONNX. diff --git a/mmdeploy/pytorch/symbolics/linear.py b/mmdeploy/pytorch/symbolics/linear.py index 8cb997b400..3236d71bf9 100644 --- a/mmdeploy/pytorch/symbolics/linear.py +++ b/mmdeploy/pytorch/symbolics/linear.py @@ -36,7 +36,7 @@ def linear_normal(g, input, weight, bias): @SYMBOLIC_REWRITER.register_symbolic( 'linear', is_pytorch=True, backend=Backend.NCNN.value) -def linear__ncnn(ctx, g, input, weight, bias): +def linear__ncnn(g, input, weight, bias): """Support export linear This rewrite enable export Gemm.""" if bias is None: return linear_no_bias(g, input, weight) diff --git a/mmdeploy/pytorch/symbolics/lstm.py b/mmdeploy/pytorch/symbolics/lstm.py index 3b8926186f..2316ef28be 100644 --- a/mmdeploy/pytorch/symbolics/lstm.py +++ b/mmdeploy/pytorch/symbolics/lstm.py @@ -13,8 +13,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='torch.onnx.symbolic_opset9._generic_rnn', backend='ncnn') -def generic_rnn__ncnn(ctx, - g, +def generic_rnn__ncnn(g, variant, input, initial_states, diff --git a/mmdeploy/pytorch/symbolics/roll.py b/mmdeploy/pytorch/symbolics/roll.py index 34b8920458..7151990d10 100644 --- 
a/mmdeploy/pytorch/symbolics/roll.py +++ b/mmdeploy/pytorch/symbolics/roll.py @@ -28,6 +28,6 @@ def roll(g, self, shifts, dims): @SYMBOLIC_REWRITER.register_symbolic('roll', is_pytorch=True) -def roll_default(ctx, g, self, shifts, dims): +def roll_default(g, self, shifts, dims): """Support export roll to ONNX with PyTorch version 1.10-.""" return roll(g, self, shifts, dims) diff --git a/mmdeploy/pytorch/symbolics/squeeze.py b/mmdeploy/pytorch/symbolics/squeeze.py index ffcac55be4..1484fa3dc7 100644 --- a/mmdeploy/pytorch/symbolics/squeeze.py +++ b/mmdeploy/pytorch/symbolics/squeeze.py @@ -5,7 +5,7 @@ @SYMBOLIC_REWRITER.register_symbolic('squeeze', is_pytorch=True) -def squeeze__default(ctx, g, self, dim=None): +def squeeze__default(g, self, dim=None): """Register default symbolic function for `squeeze`. squeeze might be exported with IF node in ONNX, which is not supported in diff --git a/tests/test_core/test_function_rewriter.py b/tests/test_core/test_function_rewriter.py index ca7a681c33..d4e33a1857 100644 --- a/tests/test_core/test_function_rewriter.py +++ b/tests/test_core/test_function_rewriter.py @@ -16,9 +16,10 @@ def test_function_rewriter(): func_name='torch.mul', backend='tensorrt') @FUNCTION_REWRITER.register_rewriter( func_name='torch.add', backend='tensorrt') - def sub_func(rewriter, x, y): - assert hasattr(rewriter, 'cfg') - assert hasattr(rewriter, 'origin_func') + def sub_func(x, y): + ctx = FUNCTION_REWRITER.get_context('torch.add') + assert hasattr(ctx, 'cfg') + assert hasattr(ctx, 'origin_func') return x - y cfg = dict() @@ -42,7 +43,7 @@ def sub_func(rewriter, x, y): # test different config @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.add', backend='default') - def mul_func_class(rewriter, x, y): + def mul_func_class(x, y): return x * y with RewriterContext(cfg, backend='tensorrt'): @@ -62,8 +63,9 @@ def mul_func_class(rewriter, x, y): # test origin_func @FUNCTION_REWRITER.register_rewriter( func_name='torch.add', 
backend='default') - def origin_add_func(rewriter, x, y, **kwargs): - return rewriter.origin_func(x, y, **kwargs) + 1 + def origin_add_func(x, y, **kwargs): + ctx = FUNCTION_REWRITER.get_context('torch.add') + return ctx.origin_func(x, y, **kwargs) + 1 with RewriterContext(cfg): result = torch.add(x, y) @@ -79,7 +81,7 @@ def test_rewrite_empty_function(): function_rewriter = FunctionRewriter() @function_rewriter.register_rewriter(func_name='torch.abcdefghijklmn') - def func(rewriter, x, y): + def func(x, y): return x + y function_rewriter.enter() @@ -101,12 +103,12 @@ def test_rewrite_homonymic_methods(self): assert c.method() == 1 @function_rewriter.register_rewriter(func_name=path1) - def func_2(ctx, self): + def func_2(self): return 2 @function_rewriter.register_rewriter( func_name=path2, backend=Backend.NCNN.value) - def func_3(ctx, self): + def func_3(self): return 3 function_rewriter.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT)) @@ -119,11 +121,11 @@ def func_3(ctx, self): @function_rewriter2.register_rewriter( func_name=path1, backend=Backend.NCNN.value) - def func_4(ctx, self): + def func_4(self): return 4 @function_rewriter2.register_rewriter(func_name=path2) - def func_5(ctx, self): + def func_5(self): return 5 function_rewriter2.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT)) @@ -147,12 +149,12 @@ def test_rewrite_derived_methods(): function_rewriter = FunctionRewriter() @function_rewriter.register_rewriter(func_name=path1) - def func_2(ctx, self): + def func_2(self): return 2 @function_rewriter.register_rewriter( func_name=path2, backend=Backend.NCNN.value) - def func_3(ctx, self): + def func_3(self): return 3 function_rewriter.enter(env=collect_env(Backend.DEFAULT, ir=IR.DEFAULT)) diff --git a/tests/test_core/test_symbolic_register.py b/tests/test_core/test_symbolic_register.py index b012f6a8b4..96bebfddf9 100644 --- a/tests/test_core/test_symbolic_register.py +++ b/tests/test_core/test_symbolic_register.py @@ -40,18 +40,19 @@ def 
test_symbolic_rewriter(): @SYMBOLIC_REWRITER.register_symbolic('mmdeploy.TestFunc', backend='ncnn') @SYMBOLIC_REWRITER.register_symbolic('mmdeploy.TestFunc') - def symbolic_testfunc_default(symbolic_wrapper, g, x, val): - assert hasattr(symbolic_wrapper, 'cfg') + def symbolic_testfunc_default(g, x, val): + ctx = SYMBOLIC_REWRITER.get_context('mmdeploy.TestFunc') + assert hasattr(ctx, 'cfg') return g.op('mmdeploy::symbolic_testfunc_default', x, val_i=val) @SYMBOLIC_REWRITER.register_symbolic( 'mmdeploy.TestFunc', backend='tensorrt') - def symbolic_testfunc_tensorrt(symbolic_wrapper, g, x, val): + def symbolic_testfunc_tensorrt(g, x, val): return g.op('mmdeploy::symbolic_testfunc_tensorrt', x, val_i=val) @SYMBOLIC_REWRITER.register_symbolic( 'cummax', is_pytorch=True, arg_descriptors=['v', 'i']) - def symbolic_cummax(symbolic_wrapper, g, input, dim): + def symbolic_cummax(g, input, dim): return g.op('mmdeploy::cummax_default', input, dim_i=dim, outputs=2) class TestModel(torch.nn.Module): @@ -103,12 +104,12 @@ def test_unregister(): test_func = mmdeploy.TestFunc.apply @SYMBOLIC_REWRITER.register_symbolic('mmdeploy.TestFunc') - def symbolic_testfunc_default(symbolic_wrapper, g, x, val): + def symbolic_testfunc_default(g, x, val): return g.op('mmdeploy::symbolic_testfunc_default', x, val_i=val) @SYMBOLIC_REWRITER.register_symbolic( 'cummax', is_pytorch=True, arg_descriptors=['v', 'i']) - def symbolic_cummax(symbolic_wrapper, g, input, dim): + def symbolic_cummax(g, input, dim): return g.op('mmdeploy::cummax_default', input, dim_i=dim, outputs=2) class TestModel(torch.nn.Module): @@ -159,7 +160,7 @@ def test_register_empty_symbolic(): symbolic_rewriter = SymbolicRewriter() @symbolic_rewriter.register_symbolic('mmdeploy.EmptyFunction') - def symbolic_testfunc_default(symbolic_wrapper, g, x, val): + def symbolic_testfunc_default(g, x, val): return g.op('mmdeploy::symbolic_testfunc_default', x, val_i=val) symbolic_rewriter.enter() diff --git 
a/tests/test_mmcv/test_mmcv_cnn.py b/tests/test_mmcv/test_mmcv_cnn.py index 4ff02438b1..4961979529 100644 --- a/tests/test_mmcv/test_mmcv_cnn.py +++ b/tests/test_mmcv/test_mmcv_cnn.py @@ -30,30 +30,3 @@ def test_multiheadattention_ncnn(): else: assert torch.allclose( model_outputs, rewrite_outputs[0], rtol=1e-03, atol=1e-05) - - -def test_conv2d_adaptive_padding_tensorrt(): - check_backend(Backend.TENSORRT) - from mmcv.cnn.bricks.conv2d_adaptive_padding import Conv2dAdaptivePadding - in_channels, out_channels = 3, 64 - kernel_sz = 3 - model = Conv2dAdaptivePadding(in_channels, out_channels, kernel_sz) - dummy_input = torch.rand(1, 3, 256, 256) - - deploy_cfg = Config( - dict( - onnx_config=dict(input_shape=None), - backend_config=dict(type=Backend.TENSORRT.value), - )) - model_outputs = model(dummy_input) - rewrite_inputs = dict(x=dummy_input) - rewrite_outputs, is_backend_output = get_rewrite_outputs( - wrapped_model=model, - model_inputs=rewrite_inputs, - deploy_cfg=deploy_cfg, - run_with_backend=True) - if is_backend_output is None: - assert rewrite_outputs is not None - else: - assert torch.allclose( - model_outputs, rewrite_outputs[0], rtol=1e-03, atol=1e-05) diff --git a/tests/test_mmcv/test_mmcv_ops.py b/tests/test_mmcv/test_mmcv_ops.py index 4d41dc2fd6..6f1a3ff71d 100644 --- a/tests/test_mmcv/test_mmcv_ops.py +++ b/tests/test_mmcv/test_mmcv_ops.py @@ -1,10 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+import os.path as osp import tempfile import onnx import pytest import torch +from mmdeploy.apis.onnx import export from mmdeploy.core import RewriterContext from mmdeploy.utils import Backend from mmdeploy.utils.test import WrapFunction, check_backend @@ -36,16 +38,15 @@ def wrapped_function(torch_bboxes, torch_scores): wrapped_model = WrapFunction(wrapped_function).eval() result = wrapped_model(boxes, scores) assert result is not None - onnx_file_path = tempfile.NamedTemporaryFile().name - with RewriterContext({}, opset=11), torch.no_grad(): - torch.onnx.export( - wrapped_model, (boxes, scores), - onnx_file_path, - export_params=True, - keep_initializers_as_inputs=True, - input_names=['boxes', 'scores'], - output_names=['result'], - opset_version=11) + onnx_file_path = tempfile.NamedTemporaryFile(suffix='.onnx').name + onnx_file_prefix = osp.splitext(onnx_file_path)[0] + export( + wrapped_model, (boxes, scores), + onnx_file_prefix, + keep_initializers_as_inputs=False, + input_names=['boxes', 'scores'], + output_names=['result'], + opset_version=11) model = onnx.load(onnx_file_path) assert model.graph.node[3].op_type == 'NonMaxSuppression' From 240ad95f763f1e432e2d9aa835af1d3a668f21b4 Mon Sep 17 00:00:00 2001 From: grimoire Date: Sat, 3 Dec 2022 01:32:18 +0800 Subject: [PATCH 2/8] update rewriter --- mmdeploy/core/rewriters/function_rewriter.py | 33 ++++++++++++++------ mmdeploy/core/rewriters/rewriter_utils.py | 21 +++++++++++-- mmdeploy/core/rewriters/symbolic_rewriter.py | 33 ++++++++++++++------ 3 files changed, 67 insertions(+), 20 deletions(-) diff --git a/mmdeploy/core/rewriters/function_rewriter.py b/mmdeploy/core/rewriters/function_rewriter.py index e6eb207550..6d601beefa 100644 --- a/mmdeploy/core/rewriters/function_rewriter.py +++ b/mmdeploy/core/rewriters/function_rewriter.py @@ -1,11 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from collections import defaultdict from typing import (Any, Callable, Dict, List, MutableSequence, Optional, Tuple, Union) from mmdeploy.utils import IR, Backend, get_root_logger from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry, - copy_function, get_frame_qual_name, - get_func_qual_name, import_function) + copy_function, get_frame_func, get_func_qual_name, + import_function) def _replace_all_obj(obj: Any, @@ -115,7 +116,7 @@ class FunctionRewriter: def __init__(self): self._registry = RewriterRegistry() - self._func_contexts = {} + self._func_contexts = defaultdict(list) def register_rewriter( self, @@ -142,7 +143,7 @@ def register_rewriter( def enter(self, cfg: Dict = dict(), env: Dict = dict(), **kwargs): """The implementation of function rewrite.""" - self._func_contexts = {} + self._func_contexts.clear() # Get current records functions_records = self._registry.get_records(env) @@ -191,8 +192,8 @@ def enter(self, cfg: Dict = dict(), env: Dict = dict(), **kwargs): cfg, **extra_kwargs) qualname = get_func_qual_name(rewrite_function) - self._func_contexts[qualname] = context_caller - self._func_contexts[function_path] = context_caller + self._func_contexts[qualname].append(context_caller) + self._func_contexts[function_path].append(context_caller) # Cache new the function to avoid homonymic bug new_functions.append( @@ -214,7 +215,7 @@ def exit(self): for func_path in self._additional_functions: _del_func(func_path) - self._func_contexts = {} + self._func_contexts.clear() def get_context(self, key: Optional[str] = None) -> ContextCaller: """Get the context of rewriter. 
@@ -225,9 +226,23 @@ def get_context(self, key: Optional[str] = None) -> ContextCaller: Returns: ContextCaller: context of function """ + func = None if key is None: - key = get_frame_qual_name(2) - ctx = self._func_contexts.get(key, None) + func = get_frame_func(2) + key = get_func_qual_name(func) + + # get all contexts + ctxs = self._func_contexts.get(key, []) + + if func is None: + assert len(ctxs) == 1 + return ctxs[0] + + ctx = None + for tmp_ctx in ctxs: + if tmp_ctx.func == func: + ctx = tmp_ctx + if ctx is None: get_root_logger().warning(f'Can not found context of {key}') return ctx diff --git a/mmdeploy/core/rewriters/rewriter_utils.py b/mmdeploy/core/rewriters/rewriter_utils.py index 97cf931609..23c8dafa83 100644 --- a/mmdeploy/core/rewriters/rewriter_utils.py +++ b/mmdeploy/core/rewriters/rewriter_utils.py @@ -396,6 +396,19 @@ def get_func_qual_name(func: Callable) -> str: return _func_name +def get_frame_func(top: int = 1) -> Callable: + """get func of frame.""" + frameinfo = inspect.stack()[top] + frame = frameinfo.frame + + g_vars = frame.f_globals + func_name = frameinfo.function + assert func_name in g_vars, \ + f'Can not find function: {func_name} in global.' 
+ func = g_vars[func_name] + return func + + def get_frame_qual_name(top: int = 1) -> str: """get frame name.""" frameinfo = inspect.stack()[top] @@ -413,12 +426,16 @@ def get_frame_qual_name(top: int = 1) -> str: def copy_function(f: types.FunctionType): """Copy the function.""" + # copy the global so we can get different func for different origin + glb = f.__globals__.copy() + name = f.__name__ g = types.FunctionType( f.__code__, - f.__globals__, - name=f.__name__, + glb, + name=name, argdefs=f.__defaults__, closure=f.__closure__) g = functools.update_wrapper(g, f) g.__kwdefaults__ = f.__kwdefaults__ + glb[name] = g return g diff --git a/mmdeploy/core/rewriters/symbolic_rewriter.py b/mmdeploy/core/rewriters/symbolic_rewriter.py index b4daa16104..118127c3a0 100644 --- a/mmdeploy/core/rewriters/symbolic_rewriter.py +++ b/mmdeploy/core/rewriters/symbolic_rewriter.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from collections import defaultdict from typing import Callable, Dict, List, Optional, Sequence, Union import torch @@ -7,8 +8,8 @@ from mmdeploy.utils import IR, Backend, get_root_logger from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry, - copy_function, eval_with_import, - get_frame_qual_name, get_func_qual_name) + copy_function, eval_with_import, get_frame_func, + get_func_qual_name) class SymbolicRewriter: @@ -35,7 +36,7 @@ class SymbolicRewriter: def __init__(self) -> None: self._registry = RewriterRegistry() - self._func_contexts = {} + self._func_contexts = defaultdict(list) def register_symbolic(self, func_name: str, @@ -78,7 +79,7 @@ def enter(self, **kwargs): """The implementation of symbolic register.""" # clear context - self._func_contexts = {} + self._func_contexts.clear() # Get current records symbolic_records = self._registry.get_records(env) @@ -98,8 +99,8 @@ def enter(self, # register context qualname = get_func_qual_name(symbolic_function) - self._func_contexts[qualname] = context_caller - 
self._func_contexts[function_name] = context_caller + self._func_contexts[qualname].append(context_caller) + self._func_contexts[function_name].append(context_caller) if arg_descriptors is not None and len(arg_descriptors) > 0: symbolic_function = parse_args(*arg_descriptors)( @@ -146,7 +147,7 @@ def enter(self, def exit(self): """The implementation of symbolic unregister.""" # clear context - self._func_contexts = {} + self._func_contexts.clear() # Unregister pytorch op if hasattr(torch.onnx, 'unregister_custom_op_symbolic'): @@ -175,9 +176,23 @@ def get_context(self, key: Optional[str] = None) -> ContextCaller: Returns: ContextCaller: context of function """ + func = None if key is None: - key = get_frame_qual_name(2) - ctx = self._func_contexts.get(key, None) + func = get_frame_func(2) + key = get_func_qual_name(func) + + # get all contexts + ctxs = self._func_contexts.get(key, []) + + if func is None: + assert len(ctxs) == 1 + return ctxs[0] + + ctx = None + for tmp_ctx in ctxs: + if tmp_ctx.func == func: + ctx = tmp_ctx + if ctx is None: get_root_logger().warning(f'Can not found context of {key}') return ctx From bee736c7a39bf031c8c2bc91ad3f19c9451b39c0 Mon Sep 17 00:00:00 2001 From: grimoire Date: Sat, 3 Dec 2022 21:53:12 +0800 Subject: [PATCH 3/8] Support all codebase --- mmdeploy/codebase/mmdet3d/models/base.py | 3 +-- mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py | 8 +++----- mmdeploy/codebase/mmdet3d/models/pillar_encode.py | 2 +- mmdeploy/codebase/mmdet3d/models/pillar_scatter.py | 6 +----- .../codebase/mmedit/models/base_models/base_edit_model.py | 1 - mmdeploy/codebase/mmocr/models/text_detection/fpn_cat.py | 2 +- mmdeploy/codebase/mmocr/models/text_detection/heads.py | 4 ++-- .../models/text_detection/single_stage_text_detector.py | 1 - .../mmocr/models/text_recognition/base_decoder.py | 1 - .../mmocr/models/text_recognition/crnn_decoder.py | 2 +- .../models/text_recognition/encoder_decoder_recognizer.py | 2 +- 
.../codebase/mmocr/models/text_recognition/lstm_layer.py | 2 +- .../codebase/mmocr/models/text_recognition/sar_decoder.py | 5 +---- .../codebase/mmocr/models/text_recognition/sar_encoder.py | 1 - mmdeploy/codebase/mmpose/models/heads/mspn_head.py | 3 ++- mmdeploy/codebase/mmpose/models/pose_estimators/base.py | 2 +- mmdeploy/codebase/mmseg/models/decode_heads/ema_head.py | 2 +- mmdeploy/codebase/mmseg/models/decode_heads/point_head.py | 5 +++-- mmdeploy/codebase/mmseg/models/segmentors/base.py | 4 ++-- .../mmseg/models/segmentors/cascade_encoder_decoder.py | 3 +-- .../codebase/mmseg/models/segmentors/encoder_decoder.py | 4 ++-- mmdeploy/codebase/mmseg/models/utils/up_conv_block.py | 3 ++- 22 files changed, 27 insertions(+), 39 deletions(-) diff --git a/mmdeploy/codebase/mmdet3d/models/base.py b/mmdeploy/codebase/mmdet3d/models/base.py index 38d35cd95c..4410e77e2e 100644 --- a/mmdeploy/codebase/mmdet3d/models/base.py +++ b/mmdeploy/codebase/mmdet3d/models/base.py @@ -9,8 +9,7 @@ @FUNCTION_REWRITER.register_rewriter( 'mmdet3d.models.detectors.Base3DDetector.forward' # noqa: E501 ) -def basedetector__forward(ctx, - self, +def basedetector__forward(self, inputs: list, data_samples=None, **kwargs) -> Tuple[List[torch.Tensor]]: diff --git a/mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py b/mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py index 83ee170885..12df74ff52 100644 --- a/mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py +++ b/mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py @@ -7,8 +7,7 @@ @FUNCTION_REWRITER.register_rewriter( 'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.extract_img_feat' # noqa: E501 ) -def mvxtwostagedetector__extract_img_feat(ctx, self, - img: torch.Tensor) -> dict: +def mvxtwostagedetector__extract_img_feat(self, img: torch.Tensor) -> dict: """Extract features of images.""" if self.with_img_backbone and img is not None: if img.dim() == 5 and img.size(0) == 1: @@ -26,8 +25,7 @@ def mvxtwostagedetector__extract_img_feat(ctx, 
self, @FUNCTION_REWRITER.register_rewriter( 'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.extract_feat') -def mvxtwostagedetector__extract_feat(ctx, self, - batch_inputs_dict: dict) -> tuple: +def mvxtwostagedetector__extract_feat(self, batch_inputs_dict: dict) -> tuple: """Rewrite this func to remove voxelize op. Args: @@ -47,7 +45,7 @@ def mvxtwostagedetector__extract_feat(ctx, self, @FUNCTION_REWRITER.register_rewriter( 'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.forward') -def mvxtwostagedetector__forward(ctx, self, inputs: list, **kwargs): +def mvxtwostagedetector__forward(self, inputs: list, **kwargs): """Rewrite this func to remove voxelize op. Args: diff --git a/mmdeploy/codebase/mmdet3d/models/pillar_encode.py b/mmdeploy/codebase/mmdet3d/models/pillar_encode.py index 4908a57071..0a327ad29d 100644 --- a/mmdeploy/codebase/mmdet3d/models/pillar_encode.py +++ b/mmdeploy/codebase/mmdet3d/models/pillar_encode.py @@ -7,7 +7,7 @@ @FUNCTION_REWRITER.register_rewriter( 'mmdet3d.models.voxel_encoders.pillar_encoder.PillarFeatureNet.forward') -def pillar_encoder__forward(ctx, self, features, num_points, coors, *args, +def pillar_encoder__forward(self, features, num_points, coors, *args, **kwargs): """Rewrite this func to optimize node. Modify the code at _with_voxel_center and use slice instead of the original operation. diff --git a/mmdeploy/codebase/mmdet3d/models/pillar_scatter.py b/mmdeploy/codebase/mmdet3d/models/pillar_scatter.py index 66ae455b5c..34351ccbc9 100644 --- a/mmdeploy/codebase/mmdet3d/models/pillar_scatter.py +++ b/mmdeploy/codebase/mmdet3d/models/pillar_scatter.py @@ -7,11 +7,7 @@ @FUNCTION_REWRITER.register_rewriter( 'mmdet3d.models.middle_encoders.pillar_scatter.' 'PointPillarsScatter.forward_batch', ) -def pointpillarsscatter__forward(ctx, - self, - voxel_features, - coors, - batch_size=1): +def pointpillarsscatter__forward(self, voxel_features, coors, batch_size=1): """Scatter features of single sample. 
Args: diff --git a/mmdeploy/codebase/mmedit/models/base_models/base_edit_model.py b/mmdeploy/codebase/mmedit/models/base_models/base_edit_model.py index eb3dad7ddf..620165ec68 100644 --- a/mmdeploy/codebase/mmedit/models/base_models/base_edit_model.py +++ b/mmdeploy/codebase/mmedit/models/base_models/base_edit_model.py @@ -10,7 +10,6 @@ @FUNCTION_REWRITER.register_rewriter( 'mmedit.models.base_models.BaseEditModel.forward', backend='default') def base_edit_model__forward( - ctx, self, batch_inputs: Tensor, data_samples: Optional[List[BaseDataElement]] = None, diff --git a/mmdeploy/codebase/mmocr/models/text_detection/fpn_cat.py b/mmdeploy/codebase/mmocr/models/text_detection/fpn_cat.py index c188d7e566..39acc2ea57 100644 --- a/mmdeploy/codebase/mmocr/models/text_detection/fpn_cat.py +++ b/mmdeploy/codebase/mmocr/models/text_detection/fpn_cat.py @@ -8,7 +8,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmocr.models.textdet.FPNC.forward', backend='tensorrt') -def fpnc__forward__tensorrt(ctx, self, inputs, **kwargs): +def fpnc__forward__tensorrt(self, inputs, **kwargs): """Rewrite `forward` of FPNC for tensorrt backend. Rewrite this function to replace nearest upsampling with bilinear diff --git a/mmdeploy/codebase/mmocr/models/text_detection/heads.py b/mmdeploy/codebase/mmocr/models/text_detection/heads.py index b7d855f902..8f6bf631e4 100644 --- a/mmdeploy/codebase/mmocr/models/text_detection/heads.py +++ b/mmdeploy/codebase/mmocr/models/text_detection/heads.py @@ -10,7 +10,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmocr.models.textdet.heads.BaseTextDetHead.predict') def base_text_det_head__predict( - ctx, self, x: torch.Tensor, + self, x: torch.Tensor, batch_data_samples: DetSampleList) -> DetSampleList: """Rewrite `predict` of BaseTextDetHead for default backend. 
@@ -38,7 +38,7 @@ def base_text_det_head__predict( @FUNCTION_REWRITER.register_rewriter( func_name='mmocr.models.textdet.heads.DBHead.predict') -def db_head__predict(ctx, self, x: torch.Tensor, +def db_head__predict(self, x: torch.Tensor, batch_data_samples: DetSampleList) -> DetSampleList: """Rewrite to avoid post-process of text detection head. diff --git a/mmdeploy/codebase/mmocr/models/text_detection/single_stage_text_detector.py b/mmdeploy/codebase/mmocr/models/text_detection/single_stage_text_detector.py index ea72eae89c..0313097afb 100644 --- a/mmdeploy/codebase/mmocr/models/text_detection/single_stage_text_detector.py +++ b/mmdeploy/codebase/mmocr/models/text_detection/single_stage_text_detector.py @@ -10,7 +10,6 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmocr.models.textdet.SingleStageTextDetector.forward') def single_stage_text_detector__forward( - ctx, self, batch_inputs: torch.Tensor, data_samples: TextDetDataSample = None, diff --git a/mmdeploy/codebase/mmocr/models/text_recognition/base_decoder.py b/mmdeploy/codebase/mmocr/models/text_recognition/base_decoder.py index 26adccaec6..036e952189 100644 --- a/mmdeploy/codebase/mmocr/models/text_recognition/base_decoder.py +++ b/mmdeploy/codebase/mmocr/models/text_recognition/base_decoder.py @@ -10,7 +10,6 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmocr.models.textrecog.decoders.BaseDecoder.predict') def base_decoder__forward( - ctx, self, feat: Optional[torch.Tensor] = None, out_enc: Optional[torch.Tensor] = None, diff --git a/mmdeploy/codebase/mmocr/models/text_recognition/crnn_decoder.py b/mmdeploy/codebase/mmocr/models/text_recognition/crnn_decoder.py index 76cb318f68..ce0696e0ba 100644 --- a/mmdeploy/codebase/mmocr/models/text_recognition/crnn_decoder.py +++ b/mmdeploy/codebase/mmocr/models/text_recognition/crnn_decoder.py @@ -5,7 +5,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmocr.models.textrecog.decoders.CRNNDecoder.forward_train', backend='ncnn') -def 
crnndecoder__forward_train__ncnn(ctx, self, feat, *args, **kwargs): +def crnndecoder__forward_train__ncnn(self, feat, *args, **kwargs): """Rewrite `forward_train` of CRNNDecoder for ncnn backend. Rewrite this function to skip permuting dims of outputs from `[W, N, C]` to diff --git a/mmdeploy/codebase/mmocr/models/text_recognition/encoder_decoder_recognizer.py b/mmdeploy/codebase/mmocr/models/text_recognition/encoder_decoder_recognizer.py index 155ece62c4..041ab9758c 100644 --- a/mmdeploy/codebase/mmocr/models/text_recognition/encoder_decoder_recognizer.py +++ b/mmdeploy/codebase/mmocr/models/text_recognition/encoder_decoder_recognizer.py @@ -7,7 +7,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmocr.models.textrecog.EncoderDecoderRecognizer.forward') -def encoder_decoder_recognizer__forward(ctx, self, batch_inputs: torch.Tensor, +def encoder_decoder_recognizer__forward(self, batch_inputs: torch.Tensor, data_samples: TextRecogDataSample, **kwargs) -> TextRecogDataSample: """Rewrite `forward` of EncoderDecoderRecognizer for default backend. diff --git a/mmdeploy/codebase/mmocr/models/text_recognition/lstm_layer.py b/mmdeploy/codebase/mmocr/models/text_recognition/lstm_layer.py index bd0d5df345..b181daaaab 100644 --- a/mmdeploy/codebase/mmocr/models/text_recognition/lstm_layer.py +++ b/mmdeploy/codebase/mmocr/models/text_recognition/lstm_layer.py @@ -6,7 +6,7 @@ func_name='mmocr.models.textrecog.layers.lstm_layer' '.BidirectionalLSTM.forward', backend='ncnn') -def bidirectionallstm__forward__ncnn(ctx, self, input): +def bidirectionallstm__forward__ncnn(self, input): """Rewrite `forward` of BidirectionalLSTM for ncnn backend. Rewrite this function to set batch_first of rnn layer to true. 
RNN in ncnn diff --git a/mmdeploy/codebase/mmocr/models/text_recognition/sar_decoder.py b/mmdeploy/codebase/mmocr/models/text_recognition/sar_decoder.py index bc371d81c3..6a5d24a823 100644 --- a/mmdeploy/codebase/mmocr/models/text_recognition/sar_decoder.py +++ b/mmdeploy/codebase/mmocr/models/text_recognition/sar_decoder.py @@ -15,7 +15,6 @@ '._2d_attention', backend='default') def parallel_sar_decoder__2d_attention( - ctx, self, decoder_input: torch.Tensor, feat: torch.Tensor, @@ -85,8 +84,7 @@ def parallel_sar_decoder__2d_attention( func_name='mmocr.models.textrecog.decoders.SequentialSARDecoder' '._2d_attention', backend='default') -def sequential_sar_decoder__2d_attention(ctx, - self, +def sequential_sar_decoder__2d_attention(self, y_prev, feat, holistic_feat, @@ -151,7 +149,6 @@ def sequential_sar_decoder__2d_attention(ctx, '.forward_test', backend='default') def sequential_sar_decoder__forward_test( - ctx, self, feat: torch.Tensor, out_enc: torch.Tensor, diff --git a/mmdeploy/codebase/mmocr/models/text_recognition/sar_encoder.py b/mmdeploy/codebase/mmocr/models/text_recognition/sar_encoder.py index 8c756fc00c..dc5a87f6f1 100644 --- a/mmdeploy/codebase/mmocr/models/text_recognition/sar_encoder.py +++ b/mmdeploy/codebase/mmocr/models/text_recognition/sar_encoder.py @@ -12,7 +12,6 @@ func_name='mmocr.models.textrecog.encoders.SAREncoder.forward', backend='default') def sar_encoder__forward( - ctx, self, feat: torch.Tensor, data_samples: Optional[Sequence[TextRecogDataSample]] = None): diff --git a/mmdeploy/codebase/mmpose/models/heads/mspn_head.py b/mmdeploy/codebase/mmpose/models/heads/mspn_head.py index 2c92d02719..7b391040f7 100644 --- a/mmdeploy/codebase/mmpose/models/heads/mspn_head.py +++ b/mmdeploy/codebase/mmpose/models/heads/mspn_head.py @@ -7,7 +7,7 @@ 'mmpose.models.heads.heatmap_heads.CPMHead.forward') @FUNCTION_REWRITER.register_rewriter( 'mmpose.models.heads.heatmap_heads.MSPNHead.forward') -def mspn_head__forward(ctx, self, feats): +def 
mspn_head__forward(self, feats): """Rewrite `forward` of MSPNHead and CPMHead for default backend. 1. return last stage heatmaps directly. @@ -18,6 +18,7 @@ def mspn_head__forward(ctx, self, feats): Returns: output_heatmap (torch.Tensor): Output heatmaps. """ + ctx = FUNCTION_REWRITER.get_context() msmu_batch_heatmaps = ctx.origin_func(self, feats) batch_heatmaps = msmu_batch_heatmaps[-1] return batch_heatmaps diff --git a/mmdeploy/codebase/mmpose/models/pose_estimators/base.py b/mmdeploy/codebase/mmpose/models/pose_estimators/base.py index a0e11e45f8..3962a4f3ff 100644 --- a/mmdeploy/codebase/mmpose/models/pose_estimators/base.py +++ b/mmdeploy/codebase/mmpose/models/pose_estimators/base.py @@ -4,7 +4,7 @@ @FUNCTION_REWRITER.register_rewriter( 'mmpose.models.pose_estimators.base.BasePoseEstimator.forward') -def base_pose_estimator__forward(ctx, self, inputs, *args, **kwargs): +def base_pose_estimator__forward(self, inputs, *args, **kwargs): """Rewrite `forward` of TopDown for default backend.'. 1.directly call _forward of subclass. diff --git a/mmdeploy/codebase/mmseg/models/decode_heads/ema_head.py b/mmdeploy/codebase/mmseg/models/decode_heads/ema_head.py index 5d839691b7..6ff07cd107 100644 --- a/mmdeploy/codebase/mmseg/models/decode_heads/ema_head.py +++ b/mmdeploy/codebase/mmseg/models/decode_heads/ema_head.py @@ -7,7 +7,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmseg.models.decode_heads.ema_head.EMAModule.forward') -def ema_module__forward(ctx, self, feats): +def ema_module__forward(self, feats): """Rewrite `forward` for default backend. Replace torch.einsum with other operations. 
diff --git a/mmdeploy/codebase/mmseg/models/decode_heads/point_head.py b/mmdeploy/codebase/mmseg/models/decode_heads/point_head.py index 717f5a7afc..09b863a875 100644 --- a/mmdeploy/codebase/mmseg/models/decode_heads/point_head.py +++ b/mmdeploy/codebase/mmseg/models/decode_heads/point_head.py @@ -7,8 +7,8 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmseg.models.decode_heads.point_head.PointHead.get_points_test', backend='tensorrt') -def point_head__get_points_test__tensorrt(ctx, self, seg_logits, - uncertainty_func, cfg): +def point_head__get_points_test__tensorrt(self, seg_logits, uncertainty_func, + cfg): """Sample points for testing. 1. set `num_points` no greater than TENSORRT_MAX_TOPK for tensorrt backend @@ -26,6 +26,7 @@ def point_head__get_points_test__tensorrt(ctx, self, seg_logits, 2) that contains [0, 1] x [0, 1] normalized coordinates of the most uncertain points from the ``height x width`` grid . """ + ctx = FUNCTION_REWRITER.get_context() from mmdeploy.utils.constants import TENSORRT_MAX_TOPK if cfg.subdivision_num_points > TENSORRT_MAX_TOPK: diff --git a/mmdeploy/codebase/mmseg/models/segmentors/base.py b/mmdeploy/codebase/mmseg/models/segmentors/base.py index 68e3196221..3606074079 100644 --- a/mmdeploy/codebase/mmseg/models/segmentors/base.py +++ b/mmdeploy/codebase/mmseg/models/segmentors/base.py @@ -7,8 +7,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmseg.models.segmentors.BaseSegmentor.forward') -def base_segmentor__forward(ctx, - self, +def base_segmentor__forward(self, inputs, data_samples=None, mode='predict', @@ -27,6 +26,7 @@ def base_segmentor__forward(ctx, Returns: torch.Tensor: Output segmentation map pf shape [N, 1, H, W]. 
""" + ctx = FUNCTION_REWRITER.get_context() if data_samples is None: data_samples = [SegDataSample()] diff --git a/mmdeploy/codebase/mmseg/models/segmentors/cascade_encoder_decoder.py b/mmdeploy/codebase/mmseg/models/segmentors/cascade_encoder_decoder.py index 30828311ad..ad8d35b81f 100644 --- a/mmdeploy/codebase/mmseg/models/segmentors/cascade_encoder_decoder.py +++ b/mmdeploy/codebase/mmseg/models/segmentors/cascade_encoder_decoder.py @@ -4,8 +4,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmseg.models.segmentors.CascadeEncoderDecoder.predict') -def cascade_encoder_decoder__predict(ctx, self, inputs, data_samples, - **kwargs): +def cascade_encoder_decoder__predict(self, inputs, data_samples, **kwargs): """Rewrite `predict` for default backend. 1. only support mode=`whole` inference diff --git a/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py b/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py index 332f39bed3..ee401b22bb 100644 --- a/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py +++ b/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py @@ -5,7 +5,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmseg.models.segmentors.EncoderDecoder.predict') -def encoder_decoder__predict(ctx, self, inputs, data_samples, **kwargs): +def encoder_decoder__predict(self, inputs, data_samples, **kwargs): """Rewrite `predict` for default backend. 1. only support mode=`whole` inference @@ -32,7 +32,7 @@ def encoder_decoder__predict(ctx, self, inputs, data_samples, **kwargs): @FUNCTION_REWRITER.register_rewriter( func_name='mmseg.models.segmentors.EncoderDecoder.predict', backend=Backend.RKNN.value) -def encoder_decoder__predict__rknn(ctx, self, inputs, data_samples, **kwargs): +def encoder_decoder__predict__rknn(self, inputs, data_samples, **kwargs): """Rewrite `predict` for RKNN backend. Early return to avoid argmax operator. 
diff --git a/mmdeploy/codebase/mmseg/models/utils/up_conv_block.py b/mmdeploy/codebase/mmseg/models/utils/up_conv_block.py index 6ccf56f2b4..bc1029976e 100644 --- a/mmdeploy/codebase/mmseg/models/utils/up_conv_block.py +++ b/mmdeploy/codebase/mmseg/models/utils/up_conv_block.py @@ -8,7 +8,7 @@ @FUNCTION_REWRITER.register_rewriter( func_name='mmseg.models.utils.UpConvBlock.forward') -def up_conv_block__forward(ctx, self, skip, x): +def up_conv_block__forward(self, skip, x): """Rewrite `forward` for default backend. To support dynamic shape for UNet backbone, @@ -23,6 +23,7 @@ def up_conv_block__forward(ctx, self, skip, x): Returns: Tensor: Upsampled output feature map. """ + ctx = FUNCTION_REWRITER.get_context() from mmcv.cnn import ConvModule # only valid when self.upsample is from build_upsample_layer From 4057523a9a3af0f8b64216ef9602cae43ed74d80 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 5 Dec 2022 11:25:34 +0800 Subject: [PATCH 4/8] update docs --- demo/demo_rewrite.py | 2 +- docs/en/07-developer-guide/partition_model.md | 7 +++---- docs/en/07-developer-guide/support_new_model.md | 5 +++-- .../07-developer-guide/test_rewritten_models.md | 5 +++-- docs/zh_cn/07-developer-guide/partition_model.md | 7 +++---- .../zh_cn/07-developer-guide/support_new_model.md | 5 +++-- .../07-developer-guide/test_rewritten_models.md | 5 +++-- mmdeploy/codebase/mmdet/deploy/utils.py | 9 ++++----- mmdeploy/codebase/mmdet/models/necks.py | 2 +- mmdeploy/core/optimizers/function_marker.py | 15 +++++++++------ mmdeploy/core/rewriters/function_rewriter.py | 3 ++- mmdeploy/core/rewriters/rewriter_utils.py | 3 ++- mmdeploy/core/rewriters/symbolic_rewriter.py | 2 +- mmdeploy/mmcv/ops/nms.py | 2 +- 14 files changed, 39 insertions(+), 33 deletions(-) diff --git a/demo/demo_rewrite.py b/demo/demo_rewrite.py index a624c26eba..a11bc9e0e1 100644 --- a/demo/demo_rewrite.py +++ b/demo/demo_rewrite.py @@ -13,7 +13,7 @@ @FUNCTION_REWRITER.register_rewriter( 
func_name='torchvision.models.ResNet._forward_impl') -def forward_of_resnet(ctx, self, x): +def forward_of_resnet(self, x): """Rewrite the forward implementation of resnet. Early return the feature map after two down-sampling steps. diff --git a/docs/en/07-developer-guide/partition_model.md b/docs/en/07-developer-guide/partition_model.md index 96aa8b73e2..f1f1420b02 100644 --- a/docs/en/07-developer-guide/partition_model.md +++ b/docs/en/07-developer-guide/partition_model.md @@ -13,13 +13,13 @@ from mmdeploy.core import FUNCTION_REWRITER, mark @mark( 'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks']) -def __forward_impl(ctx, self, img, img_metas=None, **kwargs): +def __forward_impl(self, img, img_metas=None, **kwargs): ... @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.detectors.base.BaseDetector.forward') -def base_detector__forward(ctx, self, img, img_metas=None, **kwargs): +def base_detector__forward(self, img, img_metas=None, **kwargs): ... # call the mark function return __forward_impl(...) 
@@ -32,8 +32,7 @@ from mmdeploy.core import FUNCTION_REWRITER, mark @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.dense_heads.YOLOV3Head.get_bboxes') -def yolov3_head__get_bboxes(ctx, - self, +def yolov3_head__get_bboxes(self, pred_maps, img_metas, cfg=None, diff --git a/docs/en/07-developer-guide/support_new_model.md b/docs/en/07-developer-guide/support_new_model.md index ae456a45b7..1fb4c012cf 100644 --- a/docs/en/07-developer-guide/support_new_model.md +++ b/docs/en/07-developer-guide/support_new_model.md @@ -11,7 +11,8 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.repeat', backend='tensorrt') -def repeat_static(ctx, input, *size): +def repeat_static(input, *size): + ctx = FUNCTION_REWRITER.get_context() origin_func = ctx.origin_func if input.dim() == 1 and len(size) == 1: return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0) @@ -72,7 +73,7 @@ The mappings between PyTorch and ONNX are defined in PyTorch with symbolic funct ```python @SYMBOLIC_REWRITER.register_symbolic('squeeze', is_pytorch=True) -def squeeze_default(ctx, g, self, dim=None): +def squeeze_default(g, self, dim=None): if dim is None: dims = [] for i, size in enumerate(self.type().sizes()): diff --git a/docs/en/07-developer-guide/test_rewritten_models.md b/docs/en/07-developer-guide/test_rewritten_models.md index 311e2adbd7..e81e79fe0f 100644 --- a/docs/en/07-developer-guide/test_rewritten_models.md +++ b/docs/en/07-developer-guide/test_rewritten_models.md @@ -18,7 +18,7 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta): # Custom rewritten function @FUNCTION_REWRITER.register_rewriter( 'mmcls.models.classifiers.BaseClassifier.forward', backend='default') -def forward_of_base_classifier(ctx, self, img, *args, **kwargs): +def forward_of_base_classifier(self, img, *args, **kwargs): """Rewrite `forward` for default backend.""" return self.simple_test(img, {}) ``` @@ -63,7 +63,8 @@ In the first example, 
the output is generated in Python. Sometimes we may make b # Custom rewritten function @FUNCTION_REWRITER.register_rewriter( func_name='mmseg.models.segmentors.BaseSegmentor.forward') -def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs): +def base_segmentor__forward(self, img, img_metas=None, **kwargs): + ctx = FUNCTION_REWRITER.get_context() if img_metas is None: img_metas = {} assert isinstance(img_metas, dict) diff --git a/docs/zh_cn/07-developer-guide/partition_model.md b/docs/zh_cn/07-developer-guide/partition_model.md index 2356554d47..bfcaa1058e 100644 --- a/docs/zh_cn/07-developer-guide/partition_model.md +++ b/docs/zh_cn/07-developer-guide/partition_model.md @@ -13,13 +13,13 @@ from mmdeploy.core import FUNCTION_REWRITER, mark @mark( 'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks']) -def __forward_impl(ctx, self, img, img_metas=None, **kwargs): +def __forward_impl(self, img, img_metas=None, **kwargs): ... @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.detectors.base.BaseDetector.forward') -def base_detector__forward(ctx, self, img, img_metas=None, **kwargs): +def base_detector__forward(self, img, img_metas=None, **kwargs): ... # call the mark function return __forward_impl(...) 
@@ -32,8 +32,7 @@ from mmdeploy.core import FUNCTION_REWRITER, mark @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.dense_heads.YOLOV3Head.get_bboxes') -def yolov3_head__get_bboxes(ctx, - self, +def yolov3_head__get_bboxes(self, pred_maps, img_metas, cfg=None, diff --git a/docs/zh_cn/07-developer-guide/support_new_model.md b/docs/zh_cn/07-developer-guide/support_new_model.md index 7c9cd72ada..727a9a2350 100644 --- a/docs/zh_cn/07-developer-guide/support_new_model.md +++ b/docs/zh_cn/07-developer-guide/support_new_model.md @@ -10,7 +10,8 @@ PyTorch 神经网络是用 python 编写的,可以简化算法的开发。但 from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.repeat', backend='tensorrt') -def repeat_static(ctx, input, *size): +def repeat_static(input, *size): + ctx = FUNCTION_REWRITER.get_context() origin_func = ctx.origin_func if input.dim() == 1 and len(size) == 1: return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0) @@ -67,7 +68,7 @@ PyTorch 和 ONNX 之间的映射是通过 PyTorch 中的符号函数进行定义 ```python @SYMBOLIC_REWRITER.register_symbolic('squeeze', is_pytorch=True) -def squeeze_default(ctx, g, self, dim=None): +def squeeze_default(g, self, dim=None): if dim is None: dims = [] for i, size in enumerate(self.type().sizes()): diff --git a/docs/zh_cn/07-developer-guide/test_rewritten_models.md b/docs/zh_cn/07-developer-guide/test_rewritten_models.md index 0ae0111de4..16f3a96e03 100644 --- a/docs/zh_cn/07-developer-guide/test_rewritten_models.md +++ b/docs/zh_cn/07-developer-guide/test_rewritten_models.md @@ -18,7 +18,7 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta): # Custom rewritten function @FUNCTION_REWRITER.register_rewriter( 'mmcls.models.classifiers.BaseClassifier.forward', backend='default') -def forward_of_base_classifier(ctx, self, img, *args, **kwargs): +def forward_of_base_classifier(self, img, *args, **kwargs): """Rewrite `forward` for default backend.""" return self.simple_test(img, {}) ``` @@ -63,7 +63,8 @@ 
def test_baseclassfier_forward(): # Custom rewritten function @FUNCTION_REWRITER.register_rewriter( func_name='mmseg.models.segmentors.BaseSegmentor.forward') -def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs): +def base_segmentor__forward(self, img, img_metas=None, **kwargs): + ctx = FUNCTION_REWRITER.get_context() if img_metas is None: img_metas = {} assert isinstance(img_metas, dict) diff --git a/mmdeploy/codebase/mmdet/deploy/utils.py b/mmdeploy/codebase/mmdet/deploy/utils.py index 656200234f..1ad62f97f0 100644 --- a/mmdeploy/codebase/mmdet/deploy/utils.py +++ b/mmdeploy/codebase/mmdet/deploy/utils.py @@ -76,7 +76,7 @@ def clip_bboxes(x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor, func_name='mmdeploy.codebase.mmdet.deploy.utils.clip_bboxes', backend='tensorrt', extra_checkers=LibVersionChecker('tensorrt', min_version='8')) -def clip_bboxes__trt8(ctx, x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor, +def clip_bboxes__trt8(x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor, max_shape: Union[Tensor, Sequence[int]]): """Clip bboxes for onnx. From TensorRT 8 we can do the operators on the tensors directly. 
@@ -223,12 +223,12 @@ def symbolic(g, x, inds): @FUNCTION_REWRITER.register_rewriter( 'mmdeploy.codebase.mmdet.deploy.utils.__gather_topk', backend=Backend.TENSORRT.value) -def __gather_topk__trt(ctx, - *inputs: Sequence[torch.Tensor], +def __gather_topk__trt(*inputs: Sequence[torch.Tensor], inds: torch.Tensor, batch_size: int, is_batched: bool = True) -> Tuple[torch.Tensor]: """TensorRT gather_topk.""" + ctx = FUNCTION_REWRITER.get_context() _ = ctx if is_batched: index_shape = inds.shape @@ -253,8 +253,7 @@ def __gather_topk__trt(ctx, @FUNCTION_REWRITER.register_rewriter( 'mmdeploy.codebase.mmdet.deploy.utils.__gather_topk', backend=Backend.COREML.value) -def __gather_topk__nonbatch(ctx, - *inputs: Sequence[torch.Tensor], +def __gather_topk__nonbatch(*inputs: Sequence[torch.Tensor], inds: torch.Tensor, batch_size: int, is_batched: bool = True) -> Tuple[torch.Tensor]: diff --git a/mmdeploy/codebase/mmdet/models/necks.py b/mmdeploy/codebase/mmdet/models/necks.py index 79c1d117e5..4ea29db5f5 100644 --- a/mmdeploy/codebase/mmdet/models/necks.py +++ b/mmdeploy/codebase/mmdet/models/necks.py @@ -35,6 +35,6 @@ def l2norm__forward__tensorrt(self, x): except Exception: logger.warning('Can not get TensorRT version.') if trt_version_major >= 8: - return l2norm__forward__default(ctx, self, x) + return l2norm__forward__default(self, x) else: return ctx.origin_func(self, x) diff --git a/mmdeploy/core/optimizers/function_marker.py b/mmdeploy/core/optimizers/function_marker.py index 0417fccfd9..41deef71bf 100644 --- a/mmdeploy/core/optimizers/function_marker.py +++ b/mmdeploy/core/optimizers/function_marker.py @@ -218,12 +218,15 @@ def mark(func_name: Optional[str] = None, >>> from mmdeploy.core import FUNCTION_REWRITER, mark >>> @FUNCTION_REWRITER.register_rewriter( >>> func_name='mmdet.models.roi_heads.ConvFCBBoxHead.forward') - >>> @mark( - >>> 'bbox_head_forward', - >>> inputs=['bbox_feats'], - >>> outputs=['cls_score', 'bbox_pred']) - >>> def forward_of_bbox_head(ctx, self, 
x): - >>> return ctx.origin_func(self, x) + >>> def forward_of_bbox_head(self, x): + >>> ctx = FUNCTION_REWRITER.get_context() + >>> @mark( + >>> 'bbox_head_forward', + >>> inputs=['bbox_feats'], + >>> outputs=['cls_score', 'bbox_pred']) + >>> def _impl(): + >>> return ctx.origin_func(self, x) + >>> return _impl() """ MARK_FUNCTION_COUNT[func_name] = 0 diff --git a/mmdeploy/core/rewriters/function_rewriter.py b/mmdeploy/core/rewriters/function_rewriter.py index 6d601beefa..f77e159fcc 100644 --- a/mmdeploy/core/rewriters/function_rewriter.py +++ b/mmdeploy/core/rewriters/function_rewriter.py @@ -104,7 +104,8 @@ class FunctionRewriter: Examples: >>> @FUNCTION_REWRITER.register_rewriter( >>> func_name='torch.Tensor.size', backend='ncnn') - >>> def size_of_tensor_static(ctx, self, *args): + >>> def size_of_tensor_static(self, *args): + >>> ctx = FUNCTION_REWRITER.get_context() >>> ret = ctx.origin_func(self, *args) >>> if isinstance(ret, torch.Tensor): >>> ret = int(ret) diff --git a/mmdeploy/core/rewriters/rewriter_utils.py b/mmdeploy/core/rewriters/rewriter_utils.py index 23c8dafa83..74406a464c 100644 --- a/mmdeploy/core/rewriters/rewriter_utils.py +++ b/mmdeploy/core/rewriters/rewriter_utils.py @@ -344,8 +344,9 @@ class ContextCaller: Example: >>> @FUNCTION_REWRITER.register_rewriter(func_name='torch.add') - >>> def func(ctx, x, y): + >>> def func(x, y): >>> # ctx is an instance of ContextCaller + >>> ctx = FUNCTION_REWRITER.get_context() >>> print(ctx.cfg) >>> return x + y """ diff --git a/mmdeploy/core/rewriters/symbolic_rewriter.py b/mmdeploy/core/rewriters/symbolic_rewriter.py index 118127c3a0..d3d5d753bb 100644 --- a/mmdeploy/core/rewriters/symbolic_rewriter.py +++ b/mmdeploy/core/rewriters/symbolic_rewriter.py @@ -23,7 +23,7 @@ class SymbolicRewriter: Examples: >>> @SYMBOLIC_REWRITER.register_symbolic('squeeze', \ >>> is_pytorch=True) - >>> def squeeze_default(ctx, g, self, dim=None): + >>> def squeeze_default(g, self, dim=None): >>> if dim is None: >>> dims = 
[] >>> for i, size in enumerate(self.type().sizes()): diff --git a/mmdeploy/mmcv/ops/nms.py b/mmdeploy/mmcv/ops/nms.py index 89ca2d604c..d256f9bbe5 100644 --- a/mmdeploy/mmcv/ops/nms.py +++ b/mmdeploy/mmcv/ops/nms.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from mmcv.ops import nms from torch import Tensor from torch.onnx import symbolic_helper as sym_help @@ -33,6 +32,7 @@ def forward(ctx, boxes: Tensor, scores: Tensor, (num_selected_indices, 3) with each row of [batch_index, class_index, box_index]. """ + from mmcv.ops import nms batch_size, num_class, _ = scores.shape score_threshold = float(score_threshold) From 87f662c193f1d64975e2c081848736069f2506e1 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 5 Dec 2022 15:22:48 +0800 Subject: [PATCH 5/8] fix ssd --- mmdeploy/codebase/mmdet/deploy/utils.py | 1 - tests/test_codebase/test_mmdet/test_mmdet_models.py | 8 +------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/mmdeploy/codebase/mmdet/deploy/utils.py b/mmdeploy/codebase/mmdet/deploy/utils.py index 1ad62f97f0..a7dc0b6fb6 100644 --- a/mmdeploy/codebase/mmdet/deploy/utils.py +++ b/mmdeploy/codebase/mmdet/deploy/utils.py @@ -165,7 +165,6 @@ def __pad_with_value_if_necessary(x: Tensor, 'mmdeploy.codebase.mmdet.deploy.utils.__pad_with_value_if_necessary', backend=Backend.TENSORRT.value) def __pad_with_value_if_necessary__tensorrt( - ctx, x: Tensor, pad_dim: int, pad_size: int, diff --git a/tests/test_codebase/test_mmdet/test_mmdet_models.py b/tests/test_codebase/test_mmdet/test_mmdet_models.py index 4fbc7f5faa..4ee6ea0a18 100644 --- a/tests/test_codebase/test_mmdet/test_mmdet_models.py +++ b/tests/test_codebase/test_mmdet/test_mmdet_models.py @@ -906,13 +906,7 @@ def test_forward_of_base_detector(model_cfg_path, backend): img = torch.randn(1, 3, 64, 64) from mmdet.structures import DetDataSample - from mmengine.structures import InstanceData - data_sample = DetDataSample() - img_meta = dict(img_shape=(800, 1216, 
3)) - gt_instances = InstanceData(metainfo=img_meta) - gt_instances.bboxes = torch.rand((5, 4)) - gt_instances.labels = torch.rand((5, )) - data_sample.gt_instances = gt_instances + data_sample = DetDataSample(metainfo=dict(img_shape=(800, 1216, 3))) rewrite_inputs = {'batch_inputs': img} wrapped_model = WrapModel(model, 'forward', data_samples=[data_sample]) rewrite_outputs, _ = get_rewrite_outputs( From 4348cf279fdc62149912674e7578ca8f61eabdb1 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 7 Dec 2022 11:47:14 +0800 Subject: [PATCH 6/8] rename qualname --- mmdeploy/core/rewriters/function_rewriter.py | 6 +++--- mmdeploy/core/rewriters/rewriter_utils.py | 4 ++-- mmdeploy/core/rewriters/symbolic_rewriter.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mmdeploy/core/rewriters/function_rewriter.py b/mmdeploy/core/rewriters/function_rewriter.py index f77e159fcc..78736b65a4 100644 --- a/mmdeploy/core/rewriters/function_rewriter.py +++ b/mmdeploy/core/rewriters/function_rewriter.py @@ -5,7 +5,7 @@ from mmdeploy.utils import IR, Backend, get_root_logger from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry, - copy_function, get_frame_func, get_func_qual_name, + copy_function, get_frame_func, get_func_qualname, import_function) @@ -192,7 +192,7 @@ def enter(self, cfg: Dict = dict(), env: Dict = dict(), **kwargs): context_caller = ContextCaller(rewrite_function, origin_func, cfg, **extra_kwargs) - qualname = get_func_qual_name(rewrite_function) + qualname = get_func_qualname(rewrite_function) self._func_contexts[qualname].append(context_caller) self._func_contexts[function_path].append(context_caller) @@ -230,7 +230,7 @@ def get_context(self, key: Optional[str] = None) -> ContextCaller: func = None if key is None: func = get_frame_func(2) - key = get_func_qual_name(func) + key = get_func_qualname(func) # get all contexts ctxs = self._func_contexts.get(key, []) diff --git a/mmdeploy/core/rewriters/rewriter_utils.py 
b/mmdeploy/core/rewriters/rewriter_utils.py index 74406a464c..ca1e989360 100644 --- a/mmdeploy/core/rewriters/rewriter_utils.py +++ b/mmdeploy/core/rewriters/rewriter_utils.py @@ -384,7 +384,7 @@ def wrapper(*args, **kwargs): return wrapper -def get_func_qual_name(func: Callable) -> str: +def get_func_qualname(func: Callable) -> str: """get function name.""" assert isinstance(func, Callable), f'{func} is not a Callable object.' _func_name = None @@ -410,7 +410,7 @@ def get_frame_func(top: int = 1) -> Callable: return func -def get_frame_qual_name(top: int = 1) -> str: +def get_frame_qualname(top: int = 1) -> str: """get frame name.""" frameinfo = inspect.stack()[top] frame = frameinfo.frame diff --git a/mmdeploy/core/rewriters/symbolic_rewriter.py b/mmdeploy/core/rewriters/symbolic_rewriter.py index d3d5d753bb..e045dcc356 100644 --- a/mmdeploy/core/rewriters/symbolic_rewriter.py +++ b/mmdeploy/core/rewriters/symbolic_rewriter.py @@ -9,7 +9,7 @@ from mmdeploy.utils import IR, Backend, get_root_logger from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry, copy_function, eval_with_import, get_frame_func, - get_func_qual_name) + get_func_qualname) class SymbolicRewriter: @@ -98,7 +98,7 @@ def enter(self, **extra_kwargs) # register context - qualname = get_func_qual_name(symbolic_function) + qualname = get_func_qualname(symbolic_function) self._func_contexts[qualname].append(context_caller) self._func_contexts[function_name].append(context_caller) @@ -179,7 +179,7 @@ def get_context(self, key: Optional[str] = None) -> ContextCaller: func = None if key is None: func = get_frame_func(2) - key = get_func_qual_name(func) + key = get_func_qualname(func) # get all contexts ctxs = self._func_contexts.get(key, []) From 392946a632f83db28008ecb00f6d5d7e8dd3f9d2 Mon Sep 17 00:00:00 2001 From: pppppM Date: Fri, 9 Dec 2022 22:17:16 +0800 Subject: [PATCH 7/8] support torch.fx.wrap --- .../mmdet/models/detectors/single_stage.py | 36 +++++++++++++------ 
mmdeploy/core/rewriters/function_rewriter.py | 33 +++++++++++++++++ 2 files changed, 58 insertions(+), 11 deletions(-) diff --git a/mmdeploy/codebase/mmdet/models/detectors/single_stage.py b/mmdeploy/codebase/mmdet/models/detectors/single_stage.py index 9f256fc2d8..5f3872c8b6 100644 --- a/mmdeploy/codebase/mmdet/models/detectors/single_stage.py +++ b/mmdeploy/codebase/mmdet/models/detectors/single_stage.py @@ -12,7 +12,7 @@ @mark( 'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks']) -def __forward_impl(ctx, self, batch_inputs, data_samples, **kwargs): +def __forward_impl(self, batch_inputs, data_samples): """Rewrite and adding mark for `forward`. Encapsulate this function for rewriting `forward` of BaseDetector. @@ -25,6 +25,27 @@ def __forward_impl(ctx, self, batch_inputs, data_samples, **kwargs): return output +@torch.fx.wrap +def _set_metainfo(data_samples, img_shape): + """Set the metainfo. + + Code in this function cannot be traced by fx. + """ + + # fx can not trace deepcopy correctly + data_samples = copy.deepcopy(data_samples) + if data_samples is None: + data_samples = [DetDataSample()] + + # note that we can not use `set_metainfo`, deepcopy would crash the + # onnx trace. + for data_sample in data_samples: + data_sample.set_field( + name='img_shape', value=img_shape, field_type='metainfo') + + return data_samples + + @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.detectors.single_stage.SingleStageDetector.forward') def single_stage_detector__forward(self, @@ -53,9 +74,7 @@ def single_stage_detector__forward(self, (num_instances, ). 
""" ctx = FUNCTION_REWRITER.get_context() - data_samples = copy.deepcopy(data_samples) - if data_samples is None: - data_samples = [DetDataSample()] + deploy_cfg = ctx.cfg # get origin input shape as tensor to support onnx dynamic shape @@ -65,11 +84,6 @@ def single_stage_detector__forward(self, img_shape = [int(val) for val in img_shape] # set the metainfo - # note that we can not use `set_metainfo`, deepcopy would crash the - # onnx trace. - for data_sample in data_samples: - data_sample.set_field( - name='img_shape', value=img_shape, field_type='metainfo') + data_samples = _set_metainfo(data_samples, img_shape) - return __forward_impl( - ctx, self, batch_inputs, data_samples=data_samples, **kwargs) + return __forward_impl(self, batch_inputs, data_samples=data_samples) diff --git a/mmdeploy/core/rewriters/function_rewriter.py b/mmdeploy/core/rewriters/function_rewriter.py index 78736b65a4..d0315c1cf4 100644 --- a/mmdeploy/core/rewriters/function_rewriter.py +++ b/mmdeploy/core/rewriters/function_rewriter.py @@ -1,8 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+import types from collections import defaultdict from typing import (Any, Callable, Dict, List, MutableSequence, Optional, Tuple, Union) +from torch.fx._symbolic_trace import _wrapped_fns_to_patch + from mmdeploy.utils import IR, Backend, get_root_logger from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry, copy_function, get_frame_func, get_func_qualname, @@ -94,6 +97,24 @@ def _del_func(path: str): continue +def _fx_wrap_copied_fn(func: types.FunctionType, + copied_func: types.FunctionType): + """If a function is wrapped by torch.fx.wrap, its copy also needs to be + wrapped by torch.fx.wrap.""" + if not hasattr(func, '__globals__'): + return + + wrapped_fns_globals = [item[0] for item in _wrapped_fns_to_patch] + wrapped_fns_names = [item[1] for item in _wrapped_fns_to_patch] + + # check if wrapped by torch.fx.wrap + if func.__globals__ in wrapped_fns_globals: + idx = wrapped_fns_globals.index(func.__globals__) + fn_name = wrapped_fns_names[idx] + # a hacky way to wrap the func in copied func + _wrapped_fns_to_patch.append((copied_func.__globals__, fn_name)) + + class FunctionRewriter: """A function rewriter which maintains rewritten functions. 
@@ -147,6 +168,8 @@ def enter(self, cfg: Dict = dict(), env: Dict = dict(), **kwargs): self._func_contexts.clear() # Get current records functions_records = self._registry.get_records(env) + # Get current fx wrapped func nums + self._ori_fx_wrap_num = len(_wrapped_fns_to_patch) self._origin_functions = list() self._additional_functions = list() @@ -186,11 +209,16 @@ def enter(self, cfg: Dict = dict(), env: Dict = dict(), **kwargs): # Create context_caller rewrite_function = record_dict['_object'] + # The func before and after copy has different globals rewrite_function = copy_function(rewrite_function) extra_kwargs = kwargs.copy() extra_kwargs.update(record_dict) context_caller = ContextCaller(rewrite_function, origin_func, cfg, **extra_kwargs) + # If there is a function wrapped by torch.fx.wrap in + # rewrite_function's globals, we need to wrap the same name + # function in copied function's globals. + _fx_wrap_copied_fn(record_dict['_object'], context_caller.func) qualname = get_func_qualname(rewrite_function) self._func_contexts[qualname].append(context_caller) @@ -209,6 +237,11 @@ def enter(self, cfg: Dict = dict(), env: Dict = dict(), **kwargs): def exit(self): """Recover the function rewrite.""" + # Restore _wrapped_fns_to_patch + cur_fx_wrap_num = len(_wrapped_fns_to_patch) + for _ in range(cur_fx_wrap_num - self._ori_fx_wrap_num): + _wrapped_fns_to_patch.pop(-1) + for func_dict in self._origin_functions: func_path = func_dict['func_path'] func = func_dict['origin_func'] From 8f3f7814dda74d77cbd774d9a3999010d688dc32 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 12 Dec 2022 14:37:41 +0800 Subject: [PATCH 8/8] import by torch version --- mmdeploy/core/rewriters/function_rewriter.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/mmdeploy/core/rewriters/function_rewriter.py b/mmdeploy/core/rewriters/function_rewriter.py index d0315c1cf4..7882f5fb30 100644 --- a/mmdeploy/core/rewriters/function_rewriter.py +++ 
b/mmdeploy/core/rewriters/function_rewriter.py @@ -4,13 +4,22 @@ from typing import (Any, Callable, Dict, List, MutableSequence, Optional, Tuple, Union) -from torch.fx._symbolic_trace import _wrapped_fns_to_patch - from mmdeploy.utils import IR, Backend, get_root_logger from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry, copy_function, get_frame_func, get_func_qualname, import_function) +try: + try: + # torch>=1.10.0 + from torch.fx._symbolic_trace import _wrapped_fns_to_patch + except ImportError: + # 1.10.0>torch>=1.8.0 + from torch.fx.symbolic_trace import _wrapped_fns_to_patch +except ImportError: + # torch<1.8.0 + _wrapped_fns_to_patch = [] + def _replace_all_obj(obj: Any, new_obj: Any,