From e76c0a8a8ba24b1254ef299d2efc6bf9feeb57b8 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 16 Dec 2022 21:33:12 +0800 Subject: [PATCH 1/3] fix unittest and some warning --- .../tensorrt/roi_align/trt_roi_align.cpp | 4 +- mmdeploy/backend/tensorrt/utils.py | 13 +- .../mmdet/deploy/object_detection_model.py | 2 +- mmdeploy/codebase/mmdet/models/backbones.py | 15 ++- mmdeploy/core/rewriters/rewriter_utils.py | 23 ++++ mmdeploy/mmcv/ops/nms.py | 8 +- mmdeploy/utils/test.py | 8 +- .../test_mmdet/test_mmdet_models.py | 113 ++---------------- .../test_voxel_detection_model.py | 4 + .../test_mmpose/test_mmpose_models.py | 25 ---- tests/test_core/test_function_rewriter.py | 27 +++-- tests/test_core/test_module_rewriter.py | 8 +- tests/test_ops/test_ops.py | 18 ++- 13 files changed, 101 insertions(+), 167 deletions(-) diff --git a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp index 4f1221f2cc..8e4d556c14 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp @@ -203,9 +203,9 @@ nvinfer1::IPluginV2 *TRTRoIAlignCreator::createPlugin( int data_size = fc->fields[i].length; const char *data_start = static_cast(fc->fields[i].data); std::string poolModeStr(data_start, data_size); - if (poolModeStr == "avg") { + if (strcmp(poolModeStr.c_str(), "avg") == 0) { poolMode = 1; - } else if (poolModeStr == "max") { + } else if (strcmp(poolModeStr.c_str(), "max") == 0) { poolMode = 0; } else { std::cout << "Unknown pool mode \"" << poolModeStr << "\"." << std::endl; diff --git a/mmdeploy/backend/tensorrt/utils.py b/mmdeploy/backend/tensorrt/utils.py index 0fcdd306a2..088abaf6c8 100644 --- a/mmdeploy/backend/tensorrt/utils.py +++ b/mmdeploy/backend/tensorrt/utils.py @@ -13,15 +13,17 @@ from .init_plugins import load_tensorrt_plugin -def save(engine: trt.ICudaEngine, path: str) -> None: +def save(engine: Any, path: str) -> None: """Serialize TensorRT engine to disk. Args: - engine (tensorrt.ICudaEngine): TensorRT engine to be serialized. + engine (Any): TensorRT engine to be serialized. path (str): The absolute disk path to write the engine. """ with open(path, mode='wb') as f: - f.write(bytearray(engine.serialize())) + if isinstance(engine, trt.ICudaEngine): + engine = engine.serialize() + f.write(bytearray(engine)) def load(path: str, allocator: Optional[Any] = None) -> trt.ICudaEngine: @@ -226,7 +228,10 @@ def from_onnx(onnx_model: Union[str, onnx.ModelProto], builder.int8_calibrator = config.int8_calibrator # create engine - engine = builder.build_engine(network, config) + if hasattr(builder, 'build_serialized_network'): + engine = builder.build_serialized_network(network, config) + else: + engine = builder.build_engine(network, config) assert engine is not None, 'Failed to create TensorRT engine' diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py index 2e06c7ae3f..61afb17eff 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py @@ -600,7 +600,7 @@ class labels of shape [N, num_det]. scores = out[:, :, 1:2] boxes = out[:, :, 2:6] * scales dets = torch.cat([boxes, scores], dim=2) - return dets, torch.tensor(labels, dtype=torch.int32) + return dets, labels.to(torch.int32) @__BACKEND_MODEL.register_module('sdk') diff --git a/mmdeploy/codebase/mmdet/models/backbones.py b/mmdeploy/codebase/mmdet/models/backbones.py index 6f6a72d5ca..f21853924b 100644 --- a/mmdeploy/codebase/mmdet/models/backbones.py +++ b/mmdeploy/codebase/mmdet/models/backbones.py @@ -46,7 +46,7 @@ def focus__forward__ncnn(self, x): x = x.reshape(batch_size, c * h, 1, w) _b, _c, _h, _w = x.shape - g = _c // 2 + g = torch.div(_c, 2, rounding_mode='floor') # fuse to ncnn's shufflechannel x = x.view(_b, g, 2, _h, _w) x = torch.transpose(x, 1, 2).contiguous() @@ -55,13 +55,14 @@ def focus__forward__ncnn(self, x): x = x.reshape(_b, c * h * w, 1, 1) _b, _c, _h, _w = x.shape - g = _c // 2 + g = torch.div(_c, 2, rounding_mode='floor') # fuse to ncnn's shufflechannel x = x.view(_b, g, 2, _h, _w) x = torch.transpose(x, 1, 2).contiguous() x = x.view(_b, -1, _h, _w) - x = x.reshape(_b, c * 4, h // 2, w // 2) + x = x.reshape(_b, c * 4, torch.div(h, 2, rounding_mode='floor'), + torch.div(w, 2, rounding_mode='floor')) return self.conv(x) @@ -198,8 +199,12 @@ def shift_window_msa__forward__default(self, query, hw_shape): [query, query.new_zeros(B, C, self.window_size, query.shape[-1])], dim=-2) - slice_h = (H + self.window_size - 1) // self.window_size * self.window_size - slice_w = (W + self.window_size - 1) // self.window_size * self.window_size + slice_h = torch.div( + (H + self.window_size - 1), self.window_size, + rounding_mode='floor') * self.window_size + slice_w = torch.div( + (W + self.window_size - 1), self.window_size, + rounding_mode='floor') * self.window_size query = query[:, :, :slice_h, :slice_w] query = query.permute(0, 2, 3, 1).contiguous() H_pad, W_pad = query.shape[1], query.shape[2] diff --git a/mmdeploy/core/rewriters/rewriter_utils.py b/mmdeploy/core/rewriters/rewriter_utils.py index ca1e989360..aada852aa8 100644 --- a/mmdeploy/core/rewriters/rewriter_utils.py +++ b/mmdeploy/core/rewriters/rewriter_utils.py @@ -328,6 +328,29 @@ def decorator(object): return decorator + def remove_record(self, object: Any, filter_cb: Optional[Callable] = None): + """Remove record. + + Args: + object (Any): The object to remove. + filter_cb (Callable): Check if the object need to be remove. + Defaults to None. + """ + key_to_pop = [] + for key, records in self._rewrite_records.items(): + for rec in records: + if rec['_object'] == object: + if filter_cb is not None: + if filter_cb(rec): + continue + key_to_pop.append((key, rec)) + + for key, rec in key_to_pop: + records = self._rewrite_records[key] + records.remove(rec) + if len(records) == 0: + self._rewrite_records.pop(key) + class ContextCaller: """A callable object used in RewriteContext. diff --git a/mmdeploy/mmcv/ops/nms.py b/mmdeploy/mmcv/ops/nms.py index 8a29f0b168..3abd9a9d02 100644 --- a/mmdeploy/mmcv/ops/nms.py +++ b/mmdeploy/mmcv/ops/nms.py @@ -511,7 +511,7 @@ def multiclass_nms(boxes: Tensor, @FUNCTION_REWRITER.register_rewriter( - func_name='mmdeploy.mmcv.ops.nms.bbox_nms._multiclass_nms', + func_name='mmdeploy.mmcv.ops.nms._multiclass_nms', backend=Backend.COREML.value) def multiclass_nms__coreml(boxes: Tensor, scores: Tensor, @@ -574,8 +574,7 @@ def _xywh2xyxy(boxes): @FUNCTION_REWRITER.register_rewriter( - func_name='mmdeploy.mmcv.ops.nms.bbox_nms._multiclass_nms', - ir=IR.TORCHSCRIPT) + func_name='mmdeploy.mmcv.ops.nms._multiclass_nms', ir=IR.TORCHSCRIPT) def multiclass_nms__torchscript(boxes: Tensor, scores: Tensor, max_output_boxes_per_class: int = 1000, @@ -676,8 +675,7 @@ def symbolic(g, bboxes, scores, score_thr, iou_thr, max_size_p_class, @FUNCTION_REWRITER.register_rewriter( - func_name='mmdeploy.mmcv.ops.nms.bbox_nms._multiclass_nms', - backend='ascend') + func_name='mmdeploy.mmcv.ops.nms._multiclass_nms', backend='ascend') def multiclass_nms__ascend(boxes: Tensor, scores: Tensor, max_output_boxes_per_class: int = 1000, diff --git a/mmdeploy/utils/test.py b/mmdeploy/utils/test.py index 407320bdab..89970e53ef 100644 --- a/mmdeploy/utils/test.py +++ b/mmdeploy/utils/test.py @@ -14,6 +14,11 @@ from mmengine.model import BaseModel from torch import nn +try: + from torch.testing import assert_close as torch_assert_close +except Exception: + from torch.testing import assert_allclose as torch_assert_close + import mmdeploy.codebase # noqa: F401,F403 from mmdeploy.core import RewriterContext, patch_model from mmdeploy.utils import (IR, Backend, get_backend, get_dynamic_axes, @@ -289,8 +294,7 @@ def assert_allclose(expected: List[Union[torch.Tensor, np.ndarray]], if isinstance(actual[i], (list, np.ndarray)): actual[i] = torch.tensor(actual[i]) try: - torch.testing.assert_allclose( - actual[i], expected[i], rtol=1e-03, atol=1e-05) + torch_assert_close(actual[i], expected[i], rtol=1e-03, atol=1e-05) except AssertionError as error: if tolerate_small_mismatch: assert '(0.00%)' in str(error), str(error) diff --git a/tests/test_codebase/test_mmdet/test_mmdet_models.py b/tests/test_codebase/test_mmdet/test_mmdet_models.py index 2ad56bda9c..98b5cabcb4 100644 --- a/tests/test_codebase/test_mmdet/test_mmdet_models.py +++ b/tests/test_codebase/test_mmdet/test_mmdet_models.py @@ -9,6 +9,12 @@ import numpy as np import pytest import torch + +try: + from torch.testing import assert_close as torch_assert_close +except Exception: + from torch.testing import assert_allclose as torch_assert_close + from mmengine import Config from mmengine.config import ConfigDict @@ -237,7 +243,7 @@ def single_level_grid_priors(input): # test forward with RewriterContext({}, backend_type): wrap_output = wrapped_func(x) - torch.testing.assert_allclose(output, wrap_output) + torch_assert_close(output, wrap_output) onnx_prefix = tempfile.NamedTemporaryFile().name @@ -341,23 +347,6 @@ def get_ssd_head_model(): return model -def get_fcos_head_model(): - """FCOS Head Config.""" - test_cfg = Config( - dict( - deploy_nms_pre=0, - min_bbox_size=0, - score_thr=0.05, - nms=dict(type='nms', iou_threshold=0.5), - max_per_img=100)) - - from mmdet.models.dense_heads import FCOSHead - model = FCOSHead(num_classes=4, in_channels=1, test_cfg=test_cfg) - - model.requires_grad_(False) - return model - - def get_focus_backbone_model(): """Backbone Focus Config.""" from mmdet.models.backbones.csp_darknet import Focus @@ -412,10 +401,8 @@ def get_reppoints_head_model(): def get_detrhead_model(): """DETR head Config.""" - from mmdet.models import build_head - from mmdet.utils import register_all_modules - register_all_modules() - model = build_head( + from mmdet.registry import MODELS + model = MODELS.build( dict( type='DETRHead', num_classes=4, @@ -431,8 +418,7 @@ def get_detrhead_model(): dict( type='MultiheadAttention', embed_dims=4, - num_heads=1, - dropout=0.1) + num_heads=1) ], ffn_cfgs=dict( type='FFN', @@ -442,8 +428,6 @@ def get_detrhead_model(): ffn_drop=0., act_cfg=dict(type='ReLU', inplace=True), ), - feedforward_channels=32, - ffn_dropout=0.1, operation_order=('self_attn', 'norm', 'ffn', 'norm'))), decoder=dict( type='DetrTransformerDecoder', @@ -454,8 +438,7 @@ def get_detrhead_model(): attn_cfgs=dict( type='MultiheadAttention', embed_dims=4, - num_heads=1, - dropout=0.1), + num_heads=1), ffn_cfgs=dict( type='FFN', embed_dims=4, @@ -465,7 +448,6 @@ def get_detrhead_model(): act_cfg=dict(type='ReLU', inplace=True), ), feedforward_channels=32, - ffn_dropout=0.1, operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm')), )), @@ -536,7 +518,7 @@ def test_focus_forward(backend_type): for model_output, rewrite_output in zip(model_outputs[0], rewrite_outputs): model_output = model_output.squeeze() rewrite_output = rewrite_output.squeeze() - torch.testing.assert_allclose( + torch_assert_close( model_output, rewrite_output, rtol=1e-03, atol=1e-05) @@ -578,77 +560,6 @@ def test_l2norm_forward(backend_type): model_output[0], rewrite_output, rtol=1e-03, atol=1e-05) -def test_predict_by_feat_of_fcos_head_ncnn(): - backend_type = Backend.NCNN - check_backend(backend_type) - fcos_head = get_fcos_head_model() - fcos_head.cpu().eval() - s = 128 - batch_img_metas = [{ - 'scale_factor': np.ones(4), - 'pad_shape': (s, s, 3), - 'img_shape': (s, s, 3) - }] - - output_names = ['detection_output'] - deploy_cfg = Config( - dict( - backend_config=dict(type=backend_type.value), - onnx_config=dict(output_names=output_names, input_shape=None), - codebase_config=dict( - type='mmdet', - task='ObjectDetection', - model_type='ncnn_end2end', - post_processing=dict( - score_threshold=0.05, - iou_threshold=0.5, - max_output_boxes_per_class=200, - pre_top_k=5000, - keep_top_k=100, - background_label_id=-1, - )))) - - # the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16), - # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2). - # the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16), - # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2) - seed_everything(1234) - cls_score = [ - torch.rand(1, fcos_head.num_classes, pow(2, i), pow(2, i)) - for i in range(5, 0, -1) - ] - seed_everything(5678) - bboxes = [torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)] - - seed_everything(9101) - centernesses = [ - torch.rand(1, 1, pow(2, i), pow(2, i)) for i in range(5, 0, -1) - ] - - # to get outputs of onnx model after rewrite - batch_img_metas[0]['img_shape'] = torch.Tensor([s, s]) - wrapped_model = WrapModel( - fcos_head, - 'predict_by_feat', - batch_img_metas=batch_img_metas, - with_nms=True) - rewrite_inputs = { - 'cls_scores': cls_score, - 'bbox_preds': bboxes, - 'centernesses': centernesses - } - rewrite_outputs, is_backend_output = get_rewrite_outputs( - wrapped_model=wrapped_model, - model_inputs=rewrite_inputs, - deploy_cfg=deploy_cfg) - - # output should be of shape [1, N, 6] - if is_backend_output: - assert rewrite_outputs[0].shape[-1] == 6 - else: - assert rewrite_outputs.shape[-1] == 6 - - @pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME, Backend.NCNN]) def test_predict_by_feat_of_rpn_head(backend_type: Backend): check_backend(backend_type) diff --git a/tests/test_codebase/test_mmdet3d/test_voxel_detection_model.py b/tests/test_codebase/test_mmdet3d/test_voxel_detection_model.py index 99af5ee849..5e651c4f2c 100644 --- a/tests/test_codebase/test_mmdet3d/test_voxel_detection_model.py +++ b/tests/test_codebase/test_mmdet3d/test_voxel_detection_model.py @@ -57,6 +57,10 @@ def setup_class(cls): deploy_cfg=deploy_cfg, model_cfg=model_cfg) + @classmethod + def teardown_class(cls): + cls.wrapper.recover() + @pytest.mark.skipif( reason='Only support GPU test', condition=not torch.cuda.is_available()) diff --git a/tests/test_codebase/test_mmpose/test_mmpose_models.py b/tests/test_codebase/test_mmpose/test_mmpose_models.py index bce8c73942..eb22d35193 100644 --- a/tests/test_codebase/test_mmpose/test_mmpose_models.py +++ b/tests/test_codebase/test_mmpose/test_mmpose_models.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import numpy as np import pytest import torch @@ -93,30 +92,6 @@ def forward(self, x): return model -@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) -def test_cross_resolution_weighting_forward(backend_type: Backend): - check_backend(backend_type, True) - model = get_cross_resolution_weighting_model() - model.cpu().eval() - imgs = torch.rand(1, 16, 16, 16) - deploy_cfg = generate_mmpose_deploy_config(backend_type.value) - rewrite_inputs = {'x': imgs} - model_outputs = model.forward(imgs) - wrapped_model = WrapModel(model, 'forward') - rewrite_outputs, is_backend_output = get_rewrite_outputs( - wrapped_model=wrapped_model, - model_inputs=rewrite_inputs, - deploy_cfg=deploy_cfg) - if isinstance(rewrite_outputs, dict): - rewrite_outputs = rewrite_outputs['output'] - for model_output, rewrite_output in zip(model_outputs, rewrite_outputs): - model_output = model_output.cpu().numpy() - if isinstance(rewrite_output, torch.Tensor): - rewrite_output = rewrite_output.detach().cpu().numpy() - assert np.allclose( - model_output, rewrite_output, rtol=1e-03, atol=1e-05) - - @pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) def test_estimator_forward(backend_type: Backend): check_backend(backend_type, True) diff --git a/tests/test_core/test_function_rewriter.py b/tests/test_core/test_function_rewriter.py index d4e33a1857..157b6ca4b9 100644 --- a/tests/test_core/test_function_rewriter.py +++ b/tests/test_core/test_function_rewriter.py @@ -1,6 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +try: + from torch.testing import assert_close as torch_assert_close +except Exception: + from torch.testing import assert_allclose as torch_assert_close + from mmdeploy.core import FUNCTION_REWRITER, RewriterContext from mmdeploy.core.rewriters.function_rewriter import FunctionRewriter from mmdeploy.core.rewriters.rewriter_utils import collect_env @@ -26,19 +31,19 @@ def sub_func(x, y): with RewriterContext(cfg, backend='tensorrt'): result = torch.add(x, y) # replace add with sub - torch.testing.assert_allclose(result, x - y) + torch_assert_close(result, x - y) result = torch.mul(x, y) # replace add with sub - torch.testing.assert_allclose(result, x - y) + torch_assert_close(result, x - y) result = torch.add(x, y) # recovery origin function - torch.testing.assert_allclose(result, x + y) + torch_assert_close(result, x + y) with RewriterContext(cfg): result = torch.add(x, y) # replace should not happen with wrong backend - torch.testing.assert_allclose(result, x + y) + torch_assert_close(result, x + y) # test different config @FUNCTION_REWRITER.register_rewriter( @@ -49,16 +54,16 @@ def mul_func_class(x, y): with RewriterContext(cfg, backend='tensorrt'): result = x.add(y) # replace add with multi - torch.testing.assert_allclose(result, x * y) + torch_assert_close(result, x * y) result = x.add(y) # recovery origin function - torch.testing.assert_allclose(result, x + y) + torch_assert_close(result, x + y) with RewriterContext(cfg): result = x.add(y) # replace add with multi - torch.testing.assert_allclose(result, x * y) + torch_assert_close(result, x * y) # test origin_func @FUNCTION_REWRITER.register_rewriter( @@ -70,11 +75,15 @@ def origin_add_func(x, y, **kwargs): with RewriterContext(cfg): result = torch.add(x, y) # replace with origin + 1 - torch.testing.assert_allclose(result, x + y + 1) + torch_assert_close(result, x + y + 1) # remove torch.add del FUNCTION_REWRITER._origin_functions[-1] - torch.testing.assert_allclose(torch.add(x, y), x + y) + torch_assert_close(torch.add(x, y), x + y) + + FUNCTION_REWRITER._registry.remove_record(sub_func) + FUNCTION_REWRITER._registry.remove_record(mul_func_class) + FUNCTION_REWRITER._registry.remove_record(origin_add_func) def test_rewrite_empty_function(): diff --git a/tests/test_core/test_module_rewriter.py b/tests/test_core/test_module_rewriter.py index 001756d181..8cd0e0847e 100644 --- a/tests/test_core/test_module_rewriter.py +++ b/tests/test_core/test_module_rewriter.py @@ -1,6 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +try: + from torch.testing import assert_close as torch_assert_close +except Exception: + from torch.testing import assert_allclose as torch_assert_close from mmdeploy.core import MODULE_REWRITER, patch_model @@ -29,7 +33,7 @@ def forward(self, *args, **kwargs): rewritten_model = patch_model(model, cfg=cfg, backend='tensorrt') rewritten_bottle_nect = rewritten_model.layer1[0] rewritten_result = rewritten_bottle_nect(x) - torch.testing.assert_allclose(rewritten_result, result * 2) + torch_assert_close(rewritten_result, result * 2) # wrong backend should not be rewritten model = resnet50().eval() @@ -38,7 +42,7 @@ def forward(self, *args, **kwargs): rewritten_model = patch_model(model, cfg=cfg) rewritten_bottle_nect = rewritten_model.layer1[0] rewritten_result = rewritten_bottle_nect(x) - torch.testing.assert_allclose(rewritten_result, result) + torch_assert_close(rewritten_result, result) def test_pass_redundant_args_to_model(): diff --git a/tests/test_ops/test_ops.py b/tests/test_ops/test_ops.py index a3daabd80f..6cee9acb3b 100644 --- a/tests/test_ops/test_ops.py +++ b/tests/test_ops/test_ops.py @@ -769,17 +769,13 @@ def test_gather(backend, assert importlib.util.find_spec('onnxruntime') is not None, 'onnxruntime \ not installed.' - import numpy as np - import onnxruntime - session = onnxruntime.InferenceSession(gather_model.SerializeToString()) - model_outputs = session.run( - output_names, - dict( - zip(input_names, [ - np.array(data, dtype=np.float32), - np.array(indice[0], dtype=np.int64) - ]))) - model_outputs = [model_output for model_output in model_outputs] + from mmdeploy.backend.onnxruntime import ORTWrapper + ort_model = ORTWrapper( + gather_model.SerializeToString(), + device='cpu', + output_names=output_names) + model_outputs = ort_model(dict(zip(input_names, [data, indice[0]]))) + model_outputs = ort_model.output_to_list(model_outputs) ncnn_outputs = ncnn_model( dict(zip(input_names, [data.float(), indice.float()]))) From 9447b3e471f05ba8b7b58ad4bd5ed8bf5f0d10a5 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 21 Dec 2022 12:52:53 +0800 Subject: [PATCH 2/3] fix read string --- .../backend_ops/tensorrt/roi_align/trt_roi_align.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp index 8e4d556c14..8e69e357c6 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp @@ -201,11 +201,12 @@ nvinfer1::IPluginV2 *TRTRoIAlignCreator::createPlugin( if (field_name.compare("mode") == 0) { int data_size = fc->fields[i].length; + ASSERT(data_size > 0); const char *data_start = static_cast(fc->fields[i].data); - std::string poolModeStr(data_start, data_size); - if (strcmp(poolModeStr.c_str(), "avg") == 0) { + std::string poolModeStr(data_start); + if (poolModeStr == "avg") { poolMode = 1; - } else if (strcmp(poolModeStr.c_str(), "max") == 0) { + } else if (poolModeStr == "max") { poolMode = 0; } else { std::cout << "Unknown pool mode \"" << poolModeStr << "\"." << std::endl; From cb95042c578b7fdd3b7ca6c7c0067e8d369779bf Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 22 Dec 2022 10:48:59 +0800 Subject: [PATCH 3/3] snake --- .../backend_ops/tensorrt/roi_align/trt_roi_align.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp index 8e69e357c6..988893125d 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp @@ -203,13 +203,13 @@ nvinfer1::IPluginV2 *TRTRoIAlignCreator::createPlugin( int data_size = fc->fields[i].length; ASSERT(data_size > 0); const char *data_start = static_cast(fc->fields[i].data); - std::string poolModeStr(data_start); - if (poolModeStr == "avg") { + std::string pool_mode(data_start); + if (pool_mode == "avg") { poolMode = 1; - } else if (poolModeStr == "max") { + } else if (pool_mode == "max") { poolMode = 0; } else { - std::cout << "Unknown pool mode \"" << poolModeStr << "\"." << std::endl; + std::cout << "Unknown pool mode \"" << pool_mode << "\"." << std::endl; } ASSERT(poolMode >= 0); }