Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support centernet dev1.x #1219

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
3435296
support centernet head
hanrui1sensetime Oct 18, 2022
756fa08
add centernet head ut
hanrui1sensetime Oct 18, 2022
96644a3
add centernet
hanrui1sensetime Oct 18, 2022
fe97bd1
add centernet
hanrui1sensetime Oct 18, 2022
2f20161
add support models
hanrui1sensetime Oct 18, 2022
b23ab0f
fix mdformat
hanrui1sensetime Oct 19, 2022
3bd715f
fix reg test
hanrui1sensetime Oct 19, 2022
ba46743
fix scale
hanrui1sensetime Oct 19, 2022
6004ea6
fix conflicts
hanrui1sensetime Oct 24, 2022
fecd330
fix test.py show_dir kwargs
hanrui1sensetime Oct 31, 2022
77fec5c
fix for profile in T4
hanrui1sensetime Oct 31, 2022
d042cf8
fix dynamic shape
hanrui1sensetime Nov 1, 2022
adbf53e
fix lint
hanrui1sensetime Nov 1, 2022
3a53adc
move rescale and border to outside
hanrui1sensetime Nov 2, 2022
5269e8f
fix ut
hanrui1sensetime Nov 2, 2022
58755e1
fix lint
hanrui1sensetime Nov 2, 2022
a18754b
update ort torchscript benchmark
hanrui1sensetime Nov 2, 2022
9314817
fix centernet
hanrui1sensetime Nov 3, 2022
94cf03c
fix ut
hanrui1sensetime Nov 3, 2022
5ee8bd8
remove unused file
hanrui1sensetime Nov 4, 2022
1d49368
support centernet sdk
hanrui1sensetime Nov 7, 2022
b772d66
remove unused rewriter
hanrui1sensetime Nov 7, 2022
4621aa5
fix lint
hanrui1sensetime Nov 7, 2022
0cd31a8
fix flake8
hanrui1sensetime Nov 7, 2022
ce16874
remove unused line
hanrui1sensetime Nov 8, 2022
f034022
fix lint
hanrui1sensetime Nov 8, 2022
0cd10e3
fix lint
hanrui1sensetime Nov 8, 2022
deec016
fix conflict
hanrui1sensetime Nov 8, 2022
8e8706b
fix doc links
hanrui1sensetime Nov 8, 2022
ab8ce5a
fix mdformat
hanrui1sensetime Nov 8, 2022
22a0069
fix scale_factor as default
hanrui1sensetime Nov 8, 2022
7208924
fix conflict of docs
hanrui1sensetime Nov 9, 2022
106a5f9
apart random pad and pad
hanrui1sensetime Nov 9, 2022
e6d4d6c
fix sdk
hanrui1sensetime Nov 9, 2022
29c0bb1
fix centernet docs
hanrui1sensetime Nov 10, 2022
dc6694e
fix code style of cpp
hanrui1sensetime Nov 14, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
_base_ = [
'../_base_/base_dynamic.py', '../../_base_/backends/tensorrt-fp16.py'
]

backend_config = dict(
common_config=dict(max_workspace_size=1 << 30),
model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 64, 64],
opt_shape=[1, 3, 800, 800],
max_shape=[1, 3, 800, 800])))
])
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
_base_ = [
'../_base_/base_dynamic.py', '../../_base_/backends/tensorrt-int8.py'
]

backend_config = dict(
common_config=dict(max_workspace_size=1 << 30),
model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 64, 64],
opt_shape=[1, 3, 800, 800],
max_shape=[1, 3, 800, 800])))
])
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
_base_ = ['../_base_/base_dynamic.py', '../../_base_/backends/tensorrt.py']

backend_config = dict(
common_config=dict(max_workspace_size=1 << 30),
model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 64, 64],
opt_shape=[1, 3, 800, 800],
max_shape=[1, 3, 800, 800])))
])
4 changes: 4 additions & 0 deletions csrc/mmdeploy/codebase/mmdet/object_detection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ Result<Detections> ResizeBBox::GetBBoxes(const Value& prep_res, const Tensor& de

float w_offset = 0.f;
float h_offset = 0.f;
if (prep_res.contains("border")) {
w_offset = -prep_res["border"][1].get<int>();
h_offset = -prep_res["border"][0].get<int>();
}
int ori_width = prep_res["ori_shape"][2].get<int>();
int ori_height = prep_res["ori_shape"][1].get<int>();

Expand Down
4 changes: 3 additions & 1 deletion csrc/mmdeploy/preprocess/transform/default_format_bundle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ Result<Value> DefaultFormatBundleImpl::Process(const Value& input) {
}
}
if (!output.contains("scale_factor")) {
output["scale_factor"].push_back(1.0);
for (int i = 0; i < 4; ++i) {
output["scale_factor"].push_back(1.0);
}
}
if (!output.contains("img_norm_cfg")) {
int channel = tensor.shape()[3];
Expand Down
15 changes: 15 additions & 0 deletions csrc/mmdeploy/preprocess/transform/pad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ PadImpl::PadImpl(const Value& args) : TransformImpl(args) {
} else {
arg_.pad_val = 0.0f;
}
if (args.contains("logical_or_val")) {
// logical_or mode support.
arg_.logical_or_val = args["logical_or_val"].get<int>();
arg_.add_pix_val = args.value("add_pix_val", 0);
}
arg_.pad_to_square = args.value("pad_to_square", false);
arg_.padding_mode = args.value("padding_mode", std::string("constant"));
arg_.orientation_agnostic = args.value("orientation_agnostic", false);
Expand Down Expand Up @@ -80,6 +85,16 @@ Result<Value> PadImpl::Process(const Value& input) {
output["pad_size_divisor"] = arg_.size_divisor;
output["pad_fixed_size"].push_back(pad_h);
output["pad_fixed_size"].push_back(pad_w);
} else if (arg_.logical_or_val > 0) {
int pad_h = (height | arg_.logical_or_val) + arg_.add_pix_val;
int pad_w = (width | arg_.logical_or_val) + arg_.add_pix_val;
int offset_h = pad_h / 2 - height / 2;
int offset_w = pad_w / 2 - width / 2;
padding = {offset_w, offset_h, pad_w - width - offset_w, pad_h - height - offset_h};
output["border"].push_back(offset_h);
output["border"].push_back(offset_w);
output["border"].push_back(offset_h + height);
output["border"].push_back(offset_w + width);
} else {
output_tensor = tensor;
output["pad_fixed_size"].push_back(height);
Expand Down
2 changes: 2 additions & 0 deletions csrc/mmdeploy/preprocess/transform/pad.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class MMDEPLOY_API PadImpl : public TransformImpl {
protected:
struct pad_arg_t {
std::array<int, 2> size;
int logical_or_val;
int add_pix_val;
int size_divisor;
float pad_val;
bool pad_to_square;
Expand Down
14 changes: 14 additions & 0 deletions docs/en/03-benchmark/benchmark.md
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,20 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../
<td align="center">37.4</td>
<td align="center">-</td>
</tr>
<tr>
<td align="center"><a href="https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet/centernet_r18_8xb16-crop512-140e_coco.py">CenterNet</a></td>
<td align="center">Object Detection</td>
<td align="center">COCO2017</td>
<td align="center">box AP</td>
<td align="center">25.9</td>
<td align="center">26.0</td>
<td align="center">26.0</td>
<td align="center">26.0</td>
<td align="center">25.8</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
</tr>
<tr>
<td align="center"><a href="https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox/yolox_s_8x8_300e_coco.py">YOLOX</a></td>
<td align="center">Object Detection</td>
Expand Down
1 change: 1 addition & 0 deletions docs/en/03-benchmark/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ The table below lists the models that are guaranteed to be exportable to other b
| [VFNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/vfnet) | MMDetection | N | N | N | N | N | Y | N | N |
| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/reppoints) | MMDetection | N | N | Y | N | ? | Y | N | N |
| [DETR](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/detr) | MMDetection | N | Y | Y | N | ? | N | N | N |
| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet) | MMDetection | N | Y | Y | N | ? | N | N | N |
| [ResNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/resnet) | MMClassification | Y | Y | Y | Y | Y | Y | Y | Y |
| [ResNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/resnext) | MMClassification | Y | Y | Y | Y | Y | Y | Y | Y |
| [SE-ResNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/seresnet) | MMClassification | Y | Y | Y | Y | Y | Y | Y | Y |
Expand Down
1 change: 1 addition & 0 deletions docs/en/04-supported-codebases/mmdet.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter
| [GFL](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/gfl) | Object Detection | Y | Y | N | ? | Y |
| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/reppoints) | Object Detection | N | Y | N | ? | Y |
| [DETR](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/detr) | Object Detection | Y | Y | N | ? | Y |
| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet) | Object Detection | Y | Y | N | ? | ? |
| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | Instance Segmentation | Y | N | N | N | Y |
| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/swin) | Instance Segmentation | Y | Y | N | N | N |
14 changes: 14 additions & 0 deletions docs/zh_cn/03-benchmark/benchmark.md
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,20 @@ GPU: ncnn, TensorRT, PPLNN
<td align="center">37.4</td>
<td align="center">-</td>
</tr>
<tr>
<td align="center"><a href="https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet/centernet_r18_8xb16-crop512-140e_coco.py">CenterNet</a></td>
<td align="center">Object Detection</td>
<td align="center">COCO2017</td>
<td align="center">box AP</td>
<td align="center">25.9</td>
<td align="center">26.0</td>
<td align="center">26.0</td>
<td align="center">26.0</td>
<td align="center">25.8</td>
<td align="center">-</td>
<td align="center">-</td>
<td align="center">-</td>
</tr>
<tr>
<td align="center"><a href="https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox/yolox_s_8x8_300e_coco.py">YOLOX</a></td>
<td align="center">Object Detection</td>
Expand Down
1 change: 1 addition & 0 deletions docs/zh_cn/03-benchmark/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
| [VFNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/vfnet) | MMDetection | N | N | N | N | N | Y | N | N |
| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/reppoints) | MMDetection | N | N | Y | N | ? | Y | N | N |
| [DETR](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/detr) | MMDetection | N | Y | Y | N | ? | N | N | N |
| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet) | MMDetection | N | Y | Y | N | ? | N | N | N |
| [ResNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/resnet) | MMClassification | Y | Y | Y | Y | Y | Y | Y | Y |
| [ResNeXt](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/resnext) | MMClassification | Y | Y | Y | Y | Y | Y | Y | Y |
| [SE-ResNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/seresnet) | MMClassification | Y | Y | Y | Y | Y | Y | Y | Y |
Expand Down
1 change: 1 addition & 0 deletions docs/zh_cn/04-supported-codebases/mmdet.md
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ cv2.imwrite('output_detection.png', img)
| [GFL](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/gfl) | ObjectDetection | Y | Y | N | ? | Y |
| [RepPoints](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/reppoints) | ObjectDetection | N | Y | N | ? | Y |
| [DETR](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/detr) | ObjectDetection | Y | Y | N | ? | Y |
| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/centernet) | Object Detection | Y | Y | N | ? | ? |
| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/cascade_rcnn) | InstanceSegmentation | Y | N | N | N | Y |
| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/mask_rcnn) | InstanceSegmentation | Y | Y | N | N | Y |
| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/swin) | InstanceSegmentation | Y | Y | N | N | N |
5 changes: 2 additions & 3 deletions mmdeploy/codebase/base/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,10 @@ def _merge_cfg(cfg):
' field of config. Please set '\
'`visualization=dict(type="VisualizationHook")`'

cfg.default_hooks.visualization.enable = True
cfg.default_hooks.visualization.draw = True
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
cfg.default_hooks.visualization.show = show
cfg.default_hooks.visualization.wait_time = wait_time
cfg.default_hooks.visualization.out_dir = show_dir
cfg.default_hooks.visualization.test_out_dir = show_dir
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
cfg.default_hooks.visualization.interval = interval

return cfg
Expand All @@ -202,7 +202,6 @@ def _merge_cfg(cfg):
model_cfg = _merge_cfg(model_cfg)

visualizer = self.get_visualizer(work_dir, work_dir)

from .runner import DeployTestRunner
runner = DeployTestRunner(
model=model,
Expand Down
11 changes: 11 additions & 0 deletions mmdeploy/codebase/mmdet/deploy/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,16 @@ def get_preprocess(self, *args, **kwargs) -> Dict:
'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg',
'valid_ratio'
]
# Extra pad outside datapreprocessor for CenterNet, CornerNet, etc.
for i, transform in enumerate(pipeline):
if transform['type'] == 'RandomCenterCropPad':
if transform['test_pad_mode'][0] == 'logical_or':
extra_pad = dict(
type='Pad',
logical_or_val=transform['test_pad_mode'][1],
add_pix_val=transform['test_pad_add_pix'],
)
pipeline[i] = extra_pad
transforms = [
item for item in pipeline if 'Random' not in item['type']
and 'Annotation' not in item['type']
Expand All @@ -249,6 +259,7 @@ def get_preprocess(self, *args, **kwargs) -> Dict:
transforms[i]['size'] = transforms[i].pop('scale')

data_preprocessor = model_cfg.model.data_preprocessor

transforms.insert(-1, dict(type='DefaultFormatBundle'))
transforms.insert(
-2,
Expand Down
20 changes: 13 additions & 7 deletions mmdeploy/codebase/mmdet/deploy/object_detection_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,8 @@ def forward(self,

bboxes = dets[:, :4]
scores = dets[:, 4]

# perform rescale
if rescale:
if rescale and 'scale_factor' in img_metas[i]:
scale_factor = img_metas[i]['scale_factor']
if isinstance(scale_factor, (list, tuple, np.ndarray)):
if len(scale_factor) == 2:
Expand All @@ -201,12 +200,19 @@ def forward(self,
scale_factor = torch.from_numpy(scale_factor).to(dets)
bboxes /= scale_factor

# Most of models in mmdetection 3.x use `pad_param`, but some
# models like CenterNet uses `border`.
# offset pixel of the top-left corners between original image
# and padded/enlarged image, 'pad_param' is used when exporting
# CornerNet and CentripetalNet to onnx
pad_key = None
hanrui1sensetime marked this conversation as resolved.
Show resolved Hide resolved
if 'pad_param' in img_metas[i]:
# offset pixel of the top-left corners between original image
# and padded/enlarged image, 'pad_param' is used when exporting
# CornerNet and CentripetalNet to onnx
x_off = img_metas[i]['pad_param'][2]
y_off = img_metas[i]['pad_param'][0]
pad_key = 'pad_param'
elif 'border' in img_metas[i]:
pad_key = 'border'
if pad_key is not None:
x_off = img_metas[i][pad_key][2]
y_off = img_metas[i][pad_key][0]
bboxes[:, ::2] -= x_off
bboxes[:, 1::2] -= y_off
bboxes *= (bboxes > 0)
Expand Down
1 change: 1 addition & 0 deletions mmdeploy/codebase/mmdet/models/dense_heads/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import base_dense_head # noqa: F401,F403
from . import centernet_head # noqa: F401,F403
from . import detr_head # noqa: F401,F403
from . import fovea_head # noqa: F401,F403
from . import gfl_head # noqa: F401,F403
Expand Down
45 changes: 45 additions & 0 deletions mmdeploy/codebase/mmdet/models/dense_heads/centernet_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

from torch import Tensor

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.centernet_head.CenterNetHead.predict_by_feat')
def centernet_head__predict_by_feat__default(
ctx,
self,
center_heatmap_preds: List[Tensor],
wh_preds: List[Tensor],
offset_preds: List[Tensor],
batch_img_metas: List[dict],
rescale: bool = True,
with_nms: bool = False):
"""Rewrite `centernethead` of `CenterNetHead` for default backend."""

# The dynamic shape deploy of CenterNet get wrong result on TensorRT-8.4.x
# because of TensorRT bugs, https://github.com/NVIDIA/TensorRT/issues/2299,
# FYI.

assert len(center_heatmap_preds) == len(wh_preds) == len(offset_preds) == 1
batch_center_heatmap_preds = center_heatmap_preds[0]
batch_wh_preds = wh_preds[0]
batch_offset_preds = offset_preds[0]
batch_size = batch_center_heatmap_preds.shape[0]
img_shape = batch_img_metas[0]['img_shape']
batch_det_bboxes, batch_labels = self._decode_heatmap(
batch_center_heatmap_preds,
batch_wh_preds,
batch_offset_preds,
img_shape,
k=self.test_cfg.topk,
kernel=self.test_cfg.local_maximum_kernel)
det_bboxes = batch_det_bboxes.reshape([batch_size, -1, 5])
det_labels = batch_labels.reshape(batch_size, -1)

if with_nms:
det_bboxes, det_labels = self._bboxes_nms(det_bboxes, det_labels,
self.test_cfg)
return det_bboxes, det_labels
11 changes: 11 additions & 0 deletions tests/regression/mmdet.yml
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,17 @@ models:
- *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp16

- name: CenterNet
hanrui1sensetime marked this conversation as resolved.
Show resolved Hide resolved
metafile: configs/centernet/metafile.yml
model_configs:
- configs/centernet/centernet_r18_8xb16-crop512-140e_coco.py
pipelines:
- *pipeline_ort_dynamic_fp32
- deploy_config: configs/mmdet/detection/detection_tensorrt-fp16_dynamic-64x64-800x800.py
hanrui1sensetime marked this conversation as resolved.
Show resolved Hide resolved
convert_image: *convert_image
backend_test: *default_backend_test
sdk_config: *sdk_dynamic

- name: Mask R-CNN
metafile: configs/mask_rcnn/metafile.yml
model_configs:
Expand Down
Loading