Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] fix unittest and suppress warning #1552

Merged
merged 3 commits into from
Dec 22, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,9 @@ nvinfer1::IPluginV2 *TRTRoIAlignCreator::createPlugin(
int data_size = fc->fields[i].length;
const char *data_start = static_cast<const char *>(fc->fields[i].data);
std::string poolModeStr(data_start, data_size);
if (poolModeStr == "avg") {
if (strcmp(poolModeStr.c_str(), "avg") == 0) {
grimoire marked this conversation as resolved.
Show resolved Hide resolved
poolMode = 1;
} else if (poolModeStr == "max") {
} else if (strcmp(poolModeStr.c_str(), "max") == 0) {
poolMode = 0;
} else {
std::cout << "Unknown pool mode \"" << poolModeStr << "\"." << std::endl;
Expand Down
13 changes: 9 additions & 4 deletions mmdeploy/backend/tensorrt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@
from .init_plugins import load_tensorrt_plugin


def save(engine: trt.ICudaEngine, path: str) -> None:
def save(engine: Any, path: str) -> None:
lvhan028 marked this conversation as resolved.
Show resolved Hide resolved
lvhan028 marked this conversation as resolved.
Show resolved Hide resolved
"""Serialize TensorRT engine to disk.

Args:
engine (tensorrt.ICudaEngine): TensorRT engine to be serialized.
engine (Any): TensorRT engine to be serialized.
path (str): The absolute disk path to write the engine.
"""
with open(path, mode='wb') as f:
f.write(bytearray(engine.serialize()))
if isinstance(engine, trt.ICudaEngine):
engine = engine.serialize()
f.write(bytearray(engine))


def load(path: str, allocator: Optional[Any] = None) -> trt.ICudaEngine:
Expand Down Expand Up @@ -226,7 +228,10 @@ def from_onnx(onnx_model: Union[str, onnx.ModelProto],
builder.int8_calibrator = config.int8_calibrator

# create engine
engine = builder.build_engine(network, config)
if hasattr(builder, 'build_serialized_network'):
engine = builder.build_serialized_network(network, config)
else:
engine = builder.build_engine(network, config)

assert engine is not None, 'Failed to create TensorRT engine'

Expand Down
2 changes: 1 addition & 1 deletion mmdeploy/codebase/mmdet/deploy/object_detection_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ class labels of shape [N, num_det].
scores = out[:, :, 1:2]
boxes = out[:, :, 2:6] * scales
dets = torch.cat([boxes, scores], dim=2)
return dets, torch.tensor(labels, dtype=torch.int32)
return dets, labels.to(torch.int32)


@__BACKEND_MODEL.register_module('sdk')
Expand Down
15 changes: 10 additions & 5 deletions mmdeploy/codebase/mmdet/models/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def focus__forward__ncnn(self, x):

x = x.reshape(batch_size, c * h, 1, w)
_b, _c, _h, _w = x.shape
g = _c // 2
g = torch.div(_c, 2, rounding_mode='floor')
lvhan028 marked this conversation as resolved.
Show resolved Hide resolved
# fuse to ncnn's shufflechannel
x = x.view(_b, g, 2, _h, _w)
x = torch.transpose(x, 1, 2).contiguous()
Expand All @@ -55,13 +55,14 @@ def focus__forward__ncnn(self, x):
x = x.reshape(_b, c * h * w, 1, 1)

_b, _c, _h, _w = x.shape
g = _c // 2
g = torch.div(_c, 2, rounding_mode='floor')
# fuse to ncnn's shufflechannel
x = x.view(_b, g, 2, _h, _w)
x = torch.transpose(x, 1, 2).contiguous()
x = x.view(_b, -1, _h, _w)

x = x.reshape(_b, c * 4, h // 2, w // 2)
x = x.reshape(_b, c * 4, torch.div(h, 2, rounding_mode='floor'),
torch.div(w, 2, rounding_mode='floor'))

return self.conv(x)

Expand Down Expand Up @@ -198,8 +199,12 @@ def shift_window_msa__forward__default(self, query, hw_shape):
[query,
query.new_zeros(B, C, self.window_size, query.shape[-1])],
dim=-2)
slice_h = (H + self.window_size - 1) // self.window_size * self.window_size
slice_w = (W + self.window_size - 1) // self.window_size * self.window_size
slice_h = torch.div(
(H + self.window_size - 1), self.window_size,
rounding_mode='floor') * self.window_size
slice_w = torch.div(
(W + self.window_size - 1), self.window_size,
rounding_mode='floor') * self.window_size
query = query[:, :, :slice_h, :slice_w]
query = query.permute(0, 2, 3, 1).contiguous()
H_pad, W_pad = query.shape[1], query.shape[2]
Expand Down
23 changes: 23 additions & 0 deletions mmdeploy/core/rewriters/rewriter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,29 @@ def decorator(object):

return decorator

def remove_record(self, object: Any, filter_cb: Optional[Callable] = None):
"""Remove record.

Args:
object (Any): The object to remove.
filter_cb (Callable): Check if the object need to be remove.
Defaults to None.
"""
key_to_pop = []
for key, records in self._rewrite_records.items():
for rec in records:
if rec['_object'] == object:
if filter_cb is not None:
if filter_cb(rec):
continue
key_to_pop.append((key, rec))

for key, rec in key_to_pop:
records = self._rewrite_records[key]
records.remove(rec)
if len(records) == 0:
self._rewrite_records.pop(key)


class ContextCaller:
"""A callable object used in RewriteContext.
Expand Down
8 changes: 3 additions & 5 deletions mmdeploy/mmcv/ops/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def multiclass_nms(boxes: Tensor,


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdeploy.mmcv.ops.nms.bbox_nms._multiclass_nms',
func_name='mmdeploy.mmcv.ops.nms._multiclass_nms',
backend=Backend.COREML.value)
def multiclass_nms__coreml(boxes: Tensor,
scores: Tensor,
Expand Down Expand Up @@ -574,8 +574,7 @@ def _xywh2xyxy(boxes):


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdeploy.mmcv.ops.nms.bbox_nms._multiclass_nms',
ir=IR.TORCHSCRIPT)
func_name='mmdeploy.mmcv.ops.nms._multiclass_nms', ir=IR.TORCHSCRIPT)
def multiclass_nms__torchscript(boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: int = 1000,
Expand Down Expand Up @@ -676,8 +675,7 @@ def symbolic(g, bboxes, scores, score_thr, iou_thr, max_size_p_class,


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdeploy.mmcv.ops.nms.bbox_nms._multiclass_nms',
backend='ascend')
func_name='mmdeploy.mmcv.ops.nms._multiclass_nms', backend='ascend')
def multiclass_nms__ascend(boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: int = 1000,
Expand Down
8 changes: 6 additions & 2 deletions mmdeploy/utils/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
from mmengine.model import BaseModel
from torch import nn

try:
from torch.testing import assert_close as torch_assert_close
except Exception:
from torch.testing import assert_allclose as torch_assert_close

import mmdeploy.codebase # noqa: F401,F403
from mmdeploy.core import RewriterContext, patch_model
from mmdeploy.utils import (IR, Backend, get_backend, get_dynamic_axes,
Expand Down Expand Up @@ -289,8 +294,7 @@ def assert_allclose(expected: List[Union[torch.Tensor, np.ndarray]],
if isinstance(actual[i], (list, np.ndarray)):
actual[i] = torch.tensor(actual[i])
try:
torch.testing.assert_allclose(
actual[i], expected[i], rtol=1e-03, atol=1e-05)
torch_assert_close(actual[i], expected[i], rtol=1e-03, atol=1e-05)
except AssertionError as error:
if tolerate_small_mismatch:
assert '(0.00%)' in str(error), str(error)
Expand Down
113 changes: 12 additions & 101 deletions tests/test_codebase/test_mmdet/test_mmdet_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
import numpy as np
import pytest
import torch

try:
from torch.testing import assert_close as torch_assert_close
except Exception:
from torch.testing import assert_allclose as torch_assert_close

from mmengine import Config
from mmengine.config import ConfigDict

Expand Down Expand Up @@ -237,7 +243,7 @@ def single_level_grid_priors(input):
# test forward
with RewriterContext({}, backend_type):
wrap_output = wrapped_func(x)
torch.testing.assert_allclose(output, wrap_output)
torch_assert_close(output, wrap_output)

onnx_prefix = tempfile.NamedTemporaryFile().name

Expand Down Expand Up @@ -341,23 +347,6 @@ def get_ssd_head_model():
return model


def get_fcos_head_model():
"""FCOS Head Config."""
test_cfg = Config(
dict(
deploy_nms_pre=0,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100))

from mmdet.models.dense_heads import FCOSHead
model = FCOSHead(num_classes=4, in_channels=1, test_cfg=test_cfg)

model.requires_grad_(False)
return model


def get_focus_backbone_model():
"""Backbone Focus Config."""
from mmdet.models.backbones.csp_darknet import Focus
Expand Down Expand Up @@ -412,10 +401,8 @@ def get_reppoints_head_model():

def get_detrhead_model():
"""DETR head Config."""
from mmdet.models import build_head
from mmdet.utils import register_all_modules
register_all_modules()
model = build_head(
from mmdet.registry import MODELS
model = MODELS.build(
dict(
type='DETRHead',
num_classes=4,
Expand All @@ -431,8 +418,7 @@ def get_detrhead_model():
dict(
type='MultiheadAttention',
embed_dims=4,
num_heads=1,
dropout=0.1)
num_heads=1)
],
ffn_cfgs=dict(
type='FFN',
Expand All @@ -442,8 +428,6 @@ def get_detrhead_model():
ffn_drop=0.,
act_cfg=dict(type='ReLU', inplace=True),
),
feedforward_channels=32,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
decoder=dict(
type='DetrTransformerDecoder',
Expand All @@ -454,8 +438,7 @@ def get_detrhead_model():
attn_cfgs=dict(
type='MultiheadAttention',
embed_dims=4,
num_heads=1,
dropout=0.1),
num_heads=1),
ffn_cfgs=dict(
type='FFN',
embed_dims=4,
Expand All @@ -465,7 +448,6 @@ def get_detrhead_model():
act_cfg=dict(type='ReLU', inplace=True),
),
feedforward_channels=32,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'cross_attn',
'norm', 'ffn', 'norm')),
)),
Expand Down Expand Up @@ -536,7 +518,7 @@ def test_focus_forward(backend_type):
for model_output, rewrite_output in zip(model_outputs[0], rewrite_outputs):
model_output = model_output.squeeze()
rewrite_output = rewrite_output.squeeze()
torch.testing.assert_allclose(
torch_assert_close(
model_output, rewrite_output, rtol=1e-03, atol=1e-05)


Expand Down Expand Up @@ -578,77 +560,6 @@ def test_l2norm_forward(backend_type):
model_output[0], rewrite_output, rtol=1e-03, atol=1e-05)


def test_predict_by_feat_of_fcos_head_ncnn():
backend_type = Backend.NCNN
check_backend(backend_type)
fcos_head = get_fcos_head_model()
fcos_head.cpu().eval()
s = 128
batch_img_metas = [{
'scale_factor': np.ones(4),
'pad_shape': (s, s, 3),
'img_shape': (s, s, 3)
}]

output_names = ['detection_output']
deploy_cfg = Config(
dict(
backend_config=dict(type=backend_type.value),
onnx_config=dict(output_names=output_names, input_shape=None),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
model_type='ncnn_end2end',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=5000,
keep_top_k=100,
background_label_id=-1,
))))

# the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16),
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2).
# the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16),
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2)
seed_everything(1234)
cls_score = [
torch.rand(1, fcos_head.num_classes, pow(2, i), pow(2, i))
for i in range(5, 0, -1)
]
seed_everything(5678)
bboxes = [torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]

seed_everything(9101)
centernesses = [
torch.rand(1, 1, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
]

# to get outputs of onnx model after rewrite
batch_img_metas[0]['img_shape'] = torch.Tensor([s, s])
wrapped_model = WrapModel(
fcos_head,
'predict_by_feat',
batch_img_metas=batch_img_metas,
with_nms=True)
rewrite_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
'centernesses': centernesses
}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)

# output should be of shape [1, N, 6]
if is_backend_output:
assert rewrite_outputs[0].shape[-1] == 6
else:
assert rewrite_outputs.shape[-1] == 6


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME, Backend.NCNN])
def test_predict_by_feat_of_rpn_head(backend_type: Backend):
check_backend(backend_type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ def setup_class(cls):
deploy_cfg=deploy_cfg,
model_cfg=model_cfg)

@classmethod
def teardown_class(cls):
cls.wrapper.recover()

@pytest.mark.skipif(
reason='Only support GPU test',
condition=not torch.cuda.is_available())
Expand Down
Loading