Skip to content

Commit

Permalink
Add with argmax in config for mmseg (#2038)
Browse files Browse the repository at this point in the history
* add with_argmax for model conversion in mmseg

* resolve lint
  • Loading branch information
AllentDan authored Apr 27, 2023
1 parent e9c0092 commit 5ebd10b
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 63 deletions.
2 changes: 1 addition & 1 deletion configs/mmseg/segmentation_rknn-fp16_static-320x320.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

onnx_config = dict(input_shape=[320, 320])

codebase_config = dict(model_type='rknn')
codebase_config = dict(with_argmax=False)

backend_config = dict(
input_size_list=[[3, 320, 320]],
Expand Down
2 changes: 1 addition & 1 deletion configs/mmseg/segmentation_rknn-int8_static-320x320.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

onnx_config = dict(input_shape=[320, 320])

codebase_config = dict(model_type='rknn')
codebase_config = dict(with_argmax=False)

backend_config = dict(input_size_list=[[3, 320, 320]])
2 changes: 1 addition & 1 deletion configs/mmseg/segmentation_static.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
_base_ = ['../_base_/onnx_config.py']
codebase_config = dict(type='mmseg', task='Segmentation')
codebase_config = dict(type='mmseg', task='Segmentation', with_argmax=True)
2 changes: 2 additions & 0 deletions docs/en/04-supported-codebases/mmseg.md
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,5 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter
- <i id="static_shape">PSPNet, Fast-SCNN</i> only support static shape, because [nn.AdaptiveAvgPool2d](https://github.com/open-mmlab/mmsegmentation/blob/0c87f7a0c9099844eff8e90fa3db5b0d0ca02fee/mmseg/models/decode_heads/psp_head.py#L38) is not supported by most inference backends.

- For models that only supports static shape, you should use the deployment config file of static shape such as `configs/mmseg/segmentation_tensorrt_static-1024x2048.py`.

- For users prefer deployed models generate probability feature map, put `codebase_config = dict(with_argmax=False)` in deploy configs.
2 changes: 2 additions & 0 deletions docs/zh_cn/04-supported-codebases/mmseg.md
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,5 @@ cv2.imwrite('output_segmentation.png', img)
- <i id=“static_shape”>PSPNet,Fast-SCNN</i> 仅支持静态输入,因为多数推理框架的 [nn.AdaptiveAvgPool2d](https://github.com/open-mmlab/mmsegmentation/blob/0c87f7a0c9099844eff8e90fa3db5b0d0ca02fee/mmseg/models/decode_heads/psp_head.py#L38) 不支持动态输入。

- 对于仅支持静态形状的模型,应使用静态形状的部署配置文件,例如 `configs/mmseg/segmentation_tensorrt_static-1024x2048.py`

- 对于喜欢部署模型生成概率特征图的用户,将 `codebase_config = dict(with_argmax=False)` 放在部署配置中就足够了。
6 changes: 5 additions & 1 deletion mmdeploy/codebase/mmseg/deploy/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from mmengine.registry import Registry

from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
from mmdeploy.utils import Codebase, Task, get_input_shape, get_root_logger
from mmdeploy.utils import (Codebase, Task, get_codebase_config,
get_input_shape, get_root_logger)


def process_model_config(model_cfg: mmengine.Config,
Expand Down Expand Up @@ -303,6 +304,9 @@ def get_postprocess(self, *args, **kwargs) -> Dict:
if isinstance(params, list):
params = params[-1]
postprocess = dict(params=params, type='ResizeMask')
with_argmax = get_codebase_config(self.deploy_cfg).get(
'with_argmax', True)
postprocess['with_argmax'] = with_argmax
return postprocess

def get_model_name(self, *args, **kwargs) -> str:
Expand Down
36 changes: 3 additions & 33 deletions mmdeploy/codebase/mmseg/deploy/segmentation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ def pack_result(self, batch_outputs: torch.Tensor,
for seg_pred, data_sample in zip(batch_outputs, data_samples):
# resize seg_pred to original image shape
metainfo = data_sample.metainfo
if get_codebase_config(self.deploy_cfg).get('with_argmax',
True) is False:
seg_pred = seg_pred.argmax(dim=0, keepdim=True)
if metainfo['ori_shape'] != metainfo['img_shape']:
from mmseg.models.utils import resize
ori_type = seg_pred.dtype
Expand All @@ -119,39 +122,6 @@ def pack_result(self, batch_outputs: torch.Tensor,
return predictions


@__BACKEND_MODEL.register_module('rknn')
class RKNNModel(End2EndModel):
"""SDK inference class, converts RKNN output to mmseg format."""

def forward(self,
inputs: torch.Tensor,
data_samples: Optional[List[BaseDataElement]] = None,
mode: str = 'predict'):
"""Run forward inference.
Args:
inputs (Tensor): Inputs with shape (N, C, H, W).
data_samples (list[:obj:`SegDataSample`]): The seg data
samples. It usually includes information such as
`metainfo` and `gt_sem_seg`. Default to None.
Returns:
list: A list contains predictions.
"""
assert mode == 'predict', \
'Backend model only support mode==predict,' f' but get {mode}'
if inputs.device != torch.device(self.device):
get_root_logger().warning(f'expect input device {self.device}'
f' but get {inputs.device}.')
inputs = inputs.to(self.device)
batch_outputs = self.wrapper({self.input_name: inputs})
batch_outputs = [
output.argmax(dim=1, keepdim=True)
for output in batch_outputs.values()
]
return self.pack_result(batch_outputs[0], data_samples)


@__BACKEND_MODEL.register_module('vacc_seg')
class VACCModel(End2EndModel):
"""SDK inference class, converts VACC output to mmseg format."""
Expand Down
31 changes: 5 additions & 26 deletions mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.core import FUNCTION_REWRITER, mark
from mmdeploy.utils.constants import Backend
from mmdeploy.utils import get_codebase_config


@FUNCTION_REWRITER.register_rewriter(
Expand All @@ -26,6 +26,10 @@ def encoder_decoder__predict(self, inputs, data_samples, **kwargs):
x = self.extract_feat(inputs)
seg_logit = self.decode_head.predict(x, batch_img_metas, self.test_cfg)

ctx = FUNCTION_REWRITER.get_context()
if get_codebase_config(ctx.cfg).get('with_argmax', True) is False:
return seg_logit

# mark seg_head
@mark('decode_head', outputs=['output'])
def __mark_seg_logit(seg_logit):
Expand All @@ -35,28 +39,3 @@ def __mark_seg_logit(seg_logit):

seg_pred = seg_logit.argmax(dim=1, keepdim=True)
return seg_pred


@FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.segmentors.EncoderDecoder.predict',
backend=Backend.RKNN.value)
def encoder_decoder__predict__rknn(self, inputs, data_samples, **kwargs):
"""Rewrite `predict` for RKNN backend.
Early return to avoid argmax operator.
Args:
ctx (ContextCaller): The context with additional information.
self: The instance of the original class.
inputs (Tensor): Inputs with shape (N, C, H, W).
data_samples (SampleList): The seg data samples.
Returns:
torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
"""
batch_img_metas = []
for data_sample in data_samples:
batch_img_metas.append(data_sample.metainfo)
x = self.extract_feat(inputs)
seg_logit = self.decode_head.predict(x, batch_img_metas, self.test_cfg)
return seg_logit

0 comments on commit 5ebd10b

Please sign in to comment.