diff --git a/configs/mmdet/_base_/base_openvino_dynamic-640x640.py b/configs/mmdet/_base_/base_openvino_dynamic-640x640.py new file mode 100644 index 0000000000..29bb6712da --- /dev/null +++ b/configs/mmdet/_base_/base_openvino_dynamic-640x640.py @@ -0,0 +1,6 @@ +_base_ = ['./base_dynamic.py', '../../_base_/backends/openvino.py'] + +onnx_config = dict(input_shape=None) + +backend_config = dict( + model_inputs=[dict(opt_shapes=dict(input=[1, 3, 640, 640]))]) diff --git a/configs/mmdet/detection/detection_openvino_dynamic-640x640.py b/configs/mmdet/detection/detection_openvino_dynamic-640x640.py new file mode 100644 index 0000000000..bf3bea574e --- /dev/null +++ b/configs/mmdet/detection/detection_openvino_dynamic-640x640.py @@ -0,0 +1 @@ +_base_ = ['../_base_/base_openvino_dynamic-640x640.py'] diff --git a/csrc/mmdeploy/codebase/mmcls/linear_cls.cpp b/csrc/mmdeploy/codebase/mmcls/linear_cls.cpp index 3e1f171e3f..6d22a67b3a 100644 --- a/csrc/mmdeploy/codebase/mmcls/linear_cls.cpp +++ b/csrc/mmdeploy/codebase/mmcls/linear_cls.cpp @@ -85,6 +85,8 @@ class LinearClsHead : public MMClassification { }; MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMClassification, LinearClsHead); +using ConformerHead = LinearClsHead; +MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMClassification, ConformerHead); class CropBox { public: diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py index 56ff40b7a0..08121bdfbb 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py @@ -8,31 +8,6 @@ from mmdeploy.core import FUNCTION_REWRITER -@FUNCTION_REWRITER.register_rewriter( - 'mmdet.models.dense_heads.DETRHead.forward_single') -def detrhead__forward_single__default(self, x, img_metas): - """forward_single of DETRHead. - - Ease the mask computation - """ - - batch_size = x.size(0) - - x = self.input_proj(x) - # interpolate masks to have the same spatial shape with x - masks = x.new_zeros((batch_size, x.size(-2), x.size(-1))).to(torch.bool) - - # position encoding - pos_embed = self.positional_encoding(masks) # [bs, embed_dim, h, w] - # outs_dec: [nb_dec, bs, num_query, embed_dim] - outs_dec, _ = self.transformer(x, masks, self.query_embedding.weight, - pos_embed) - all_cls_scores = self.fc_cls(outs_dec) - all_bbox_preds = self.fc_reg(self.activate( - self.reg_ffn(outs_dec))).sigmoid() - return all_cls_scores, all_bbox_preds - - @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.dense_heads.DETRHead.predict_by_feat') def detrhead__predict_by_feat__default(self, @@ -42,8 +17,8 @@ def detrhead__predict_by_feat__default(self, rescale: bool = True): """Rewrite `predict_by_feat` of `FoveaHead` for default backend.""" from mmdet.structures.bbox import bbox_cxcywh_to_xyxy - cls_scores = all_cls_scores_list[-1][-1] - bbox_preds = all_bbox_preds_list[-1][-1] + cls_scores = all_cls_scores_list[-1] + bbox_preds = all_bbox_preds_list[-1] img_shape = batch_img_metas[0]['img_shape'] max_per_img = self.test_cfg.get('max_per_img', len(cls_scores[0])) diff --git a/mmdeploy/codebase/mmdet/models/detectors/__init__.py b/mmdeploy/codebase/mmdet/models/detectors/__init__.py index 5b9df70a08..2c0a2f3ed5 100644 --- a/mmdeploy/codebase/mmdet/models/detectors/__init__.py +++ b/mmdeploy/codebase/mmdet/models/detectors/__init__.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from . import single_stage, single_stage_instance_seg, two_stage +from . import base_detr, single_stage, single_stage_instance_seg, two_stage -__all__ = ['single_stage', 'single_stage_instance_seg', 'two_stage'] +__all__ = [ + 'base_detr', 'single_stage', 'single_stage_instance_seg', 'two_stage' +] diff --git a/mmdeploy/codebase/mmdet/models/detectors/base_detr.py b/mmdeploy/codebase/mmdet/models/detectors/base_detr.py new file mode 100644 index 0000000000..3531c9183c --- /dev/null +++ b/mmdeploy/codebase/mmdet/models/detectors/base_detr.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch +from mmdet.models.detectors.base import ForwardResults +from mmdet.structures import DetDataSample +from mmdet.structures.det_data_sample import OptSampleList + +from mmdeploy.core import FUNCTION_REWRITER, mark +from mmdeploy.utils import is_dynamic_shape + + +@mark('detr_predict', inputs=['input'], outputs=['dets', 'labels', 'masks']) +def __predict_impl(self, batch_inputs, data_samples, rescale): + """Rewrite and adding mark for `predict`. + + Encapsulate this function for rewriting `predict` of DetectionTransformer. + 1. Add mark for DetectionTransformer. + 2. Support both dynamic and static export to onnx. + """ + img_feats = self.extract_feat(batch_inputs) + head_inputs_dict = self.forward_transformer(img_feats, data_samples) + results_list = self.bbox_head.predict( + **head_inputs_dict, rescale=rescale, batch_data_samples=data_samples) + return results_list + + +@torch.fx.wrap +def _set_metainfo(data_samples, img_shape): + """Set the metainfo. + + Code in this function cannot be traced by fx. + """ + + # fx can not trace deepcopy correctly + data_samples = copy.deepcopy(data_samples) + if data_samples is None: + data_samples = [DetDataSample()] + + # note that we can not use `set_metainfo`, deepcopy would crash the + # onnx trace. + for data_sample in data_samples: + data_sample.set_field( + name='img_shape', value=img_shape, field_type='metainfo') + + return data_samples + + +@FUNCTION_REWRITER.register_rewriter( + 'mmdet.models.detectors.base_detr.DetectionTransformer.predict') +def detection_transformer__predict(self, + batch_inputs: torch.Tensor, + data_samples: OptSampleList = None, + rescale: bool = True, + **kwargs) -> ForwardResults: + """Rewrite `predict` for default backend. + + Support configured dynamic/static shape for model input and return + detection result as Tensor instead of numpy array. + + Args: + batch_inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (Boolean): rescale result or not. + + Returns: + tuple[Tensor]: Detection results of the + input images. + - dets (Tensor): Classification bboxes and scores. + Has a shape (num_instances, 5) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + """ + ctx = FUNCTION_REWRITER.get_context() + + deploy_cfg = ctx.cfg + + # get origin input shape as tensor to support onnx dynamic shape + is_dynamic_flag = is_dynamic_shape(deploy_cfg) + img_shape = torch._shape_as_tensor(batch_inputs)[2:] + if not is_dynamic_flag: + img_shape = [int(val) for val in img_shape] + + # set the metainfo + data_samples = _set_metainfo(data_samples, img_shape) + + return __predict_impl(self, batch_inputs, data_samples, rescale) diff --git a/mmdeploy/pytorch/functions/interpolate.py b/mmdeploy/pytorch/functions/interpolate.py index 39424b8a39..10340d37b7 100644 --- a/mmdeploy/pytorch/functions/interpolate.py +++ b/mmdeploy/pytorch/functions/interpolate.py @@ -81,7 +81,7 @@ def interpolate__tensorrt( size: Optional[Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]] = None, scale_factor: Optional[Union[float, Tuple[float]]] = None, - mode: str = 'bilinear', + mode: str = 'nearest', align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, ): diff --git a/tests/regression/mmdet.yml b/tests/regression/mmdet.yml index 4f2a938f19..947dfe93cc 100644 --- a/tests/regression/mmdet.yml +++ b/tests/regression/mmdet.yml @@ -250,7 +250,9 @@ models: - *pipeline_ort_dynamic_fp32 - *pipeline_trt_dynamic_fp32 - *pipeline_ncnn_static_fp32 - - *pipeline_openvino_dynamic_fp32 + - deploy_config: configs/mmdet/detection/detection_openvino_dynamic-640x640.py + convert_image: *convert_image + backend_test: False - name: Faster R-CNN metafile: configs/faster_rcnn/metafile.yml @@ -298,7 +300,10 @@ models: - configs/detr/detr_r50_8xb2-150e_coco.py pipelines: - *pipeline_ort_dynamic_fp32 - - *pipeline_trt_dynamic_fp16 + - deploy_config: configs/mmdet/detection/detection_tensorrt-fp16_dynamic-64x64-800x800.py + convert_image: *convert_image + backend_test: *default_backend_test + sdk_config: *sdk_dynamic - name: CenterNet metafile: configs/centernet/metafile.yml @@ -335,7 +340,7 @@ models: - configs/rtmdet/rtmdet_s_8xb32-300e_coco.py pipelines: - *pipeline_ort_dynamic_fp32 - - deploy_config: configs/mmdet/detection/detection_tensorrt_static-640x640.py + - deploy_config: configs/mmdet/detection/detection_tensorrt_dynamic-64x64-800x800.py convert_image: *convert_image backend_test: *default_backend_test sdk_config: *sdk_dynamic diff --git a/tests/test_codebase/test_mmcls/test_mmcls_models.py b/tests/test_codebase/test_mmcls/test_mmcls_models.py index a1037fc0c4..ba6a37b5e7 100644 --- a/tests/test_codebase/test_mmcls/test_mmcls_models.py +++ b/tests/test_codebase/test_mmcls/test_mmcls_models.py @@ -29,6 +29,14 @@ def get_invertedresidual_model(): return model +def get_fcuup_model(): + from mmcls.models.backbones.conformer import FCUUp + model = FCUUp(16, 16, 16) + + model.requires_grad_(False) + return model + + def get_vit_backbone(): from mmcls.models.classifiers.image import ImageClassifier model = ImageClassifier( diff --git a/tests/test_codebase/test_mmdet/data/detr_model.json b/tests/test_codebase/test_mmdet/data/detr_model.json new file mode 100644 index 0000000000..0d4c2ba6b6 --- /dev/null +++ b/tests/test_codebase/test_mmdet/data/detr_model.json @@ -0,0 +1,129 @@ +{ + "type": "DETR", + "num_queries": 100, + "data_preprocessor": { + "type": "DetDataPreprocessor", + "mean": [123.675, 116.28, 103.53], + "std": [58.395, 57.12, 57.375], + "bgr_to_rgb": true, + "pad_size_divisor": 1 + }, + "backbone": { + "type": "ResNet", + "depth": 50, + "num_stages": 4, + "out_indices": [3], + "frozen_stages": 1, + "norm_cfg": { + "type": "BN", + "requires_grad": false + }, + "norm_eval": true, + "style": "pytorch", + "init_cfg": { + "type": "Pretrained", + "checkpoint": "torchvision://resnet50" + } + }, + "neck": { + "type": "ChannelMapper", + "in_channels": [2048], + "kernel_size": 1, + "out_channels": 256, + "num_outs": 1 + }, + "encoder": { + "num_layers": 6, + "layer_cfg": { + "self_attn_cfg": { + "embed_dims": 256, + "num_heads": 8, + "dropout": 0.1, + "batch_first": true + }, + "ffn_cfg": { + "embed_dims": 256, + "feedforward_channels": 2048, + "num_fcs": 2, + "ffn_drop": 0.1, + "act_cfg": { + "type": "ReLU", + "inplace": true + } + } + } + }, + "decoder": { + "num_layers": 6, + "layer_cfg": { + "self_attn_cfg": { + "embed_dims": 256, + "num_heads": 8, + "dropout": 0.1, + "batch_first": true + }, + "cross_attn_cfg": { + "embed_dims": 256, + "num_heads": 8, + "dropout": 0.1, + "batch_first": true + }, + "ffn_cfg": { + "embed_dims": 256, + "feedforward_channels": 2048, + "num_fcs": 2, + "ffn_drop": 0.1, + "act_cfg": { + "type": "ReLU", + "inplace": true + } + } + }, + "return_intermediate": true + }, + "positional_encoding": { + "num_feats": 128, + "normalize": true + }, + "bbox_head": { + "type": "DETRHead", + "num_classes": 80, + "embed_dims": 256, + "loss_cls": { + "type": "CrossEntropyLoss", + "bg_cls_weight": 0.1, + "use_sigmoid": false, + "loss_weight": 1.0, + "class_weight": 1.0 + }, + "loss_bbox": { + "type": "L1Loss", + "loss_weight": 5.0 + }, + "loss_iou": { + "type": "GIoULoss", + "loss_weight": 2.0 + } + }, + "train_cfg": { + "assigner": { + "type": + "HungarianAssigner", + "match_costs": [{ + "type": "ClassificationCost", + "weight": 1.0 + }, { + "type": "BBoxL1Cost", + "weight": 5.0, + "box_format": "xywh" + }, { + "type": "IoUCost", + "iou_mode": "giou", + "weight": 2.0 + }] + } + }, + "test_cfg": { + "max_per_img": 100 + } +} diff --git a/tests/test_codebase/test_mmdet/test_mmdet_models.py b/tests/test_codebase/test_mmdet/test_mmdet_models.py index 376806e793..32110c2b01 100644 --- a/tests/test_codebase/test_mmdet/test_mmdet_models.py +++ b/tests/test_codebase/test_mmdet/test_mmdet_models.py @@ -9,6 +9,7 @@ import numpy as np import pytest import torch +from packaging import version try: from torch.testing import assert_close as torch_assert_close @@ -691,6 +692,50 @@ def test_forward_of_base_detector(model_cfg_path, backend): assert rewrite_outputs is not None +@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME]) +@pytest.mark.skipif( + reason='mha only support torch greater than 1.10.0', + condition=version.parse(torch.__version__) < version.parse('1.10.0')) +@pytest.mark.parametrize( + 'model_cfg_path', ['tests/test_codebase/test_mmdet/data/detr_model.json']) +def test_predict_of_detr_detector(model_cfg_path, backend): + # Skip test when torch.__version__ < 1.10.0 + # See https://github.com/open-mmlab/mmdeploy/discussions/1434 + check_backend(backend) + deploy_cfg = Config( + dict( + backend_config=dict(type=backend.value), + onnx_config=dict( + output_names=['dets', 'labels'], input_shape=None), + codebase_config=dict( + type='mmdet', + task='ObjectDetection', + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.5, + max_output_boxes_per_class=200, + pre_top_k=-1, + keep_top_k=100, + background_label_id=-1, + export_postprocess_mask=False, + )))) + model_cfg = Config(dict(model=mmengine.load(model_cfg_path))) + from mmdet.apis import init_detector + model = init_detector(model_cfg, None, device='cpu', palette='coco') + + img = torch.randn(1, 3, 64, 64) + from mmdet.structures import DetDataSample + data_sample = DetDataSample(metainfo=dict(batch_input_shape=(64, 64))) + rewrite_inputs = {'batch_inputs': img} + wrapped_model = WrapModel(model, 'predict', data_samples=[data_sample]) + rewrite_outputs, _ = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg) + + assert rewrite_outputs is not None + + @pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME, Backend.OPENVINO]) def test_single_roi_extractor(backend_type: Backend): @@ -1995,7 +2040,7 @@ def test_mlvl_point_generator__single_level_grid_priors__tensorrt( @pytest.mark.parametrize('backend_type, ir_type', [(Backend.ONNXRUNTIME, 'onnx')]) def test_detrhead__predict_by_feat(backend_type: Backend, ir_type: str): - """Test predict_by_feat rewrite of base dense head.""" + """Test predict_by_feat rewrite of detr head.""" check_backend(backend_type) dense_head = get_detrhead_model() dense_head.cpu().eval() @@ -2009,9 +2054,9 @@ def test_detrhead__predict_by_feat(backend_type: Backend, ir_type: str): deploy_cfg = get_deploy_cfg(backend_type, ir_type) seed_everything(1234) - cls_score = [[torch.rand(1, 100, 5) for i in range(5, 0, -1)]] + cls_score = [torch.rand(1, 100, 5) for i in range(5, 0, -1)] seed_everything(5678) - bboxes = [[torch.rand(1, 100, 4) for i in range(5, 0, -1)]] + bboxes = [torch.rand(1, 100, 4) for i in range(5, 0, -1)] # to get outputs of onnx model after rewrite img_metas[0]['img_shape'] = torch.Tensor([s, s])