diff --git a/docs/en/backends/openvino.md b/docs/en/backends/openvino.md
index 12a6686d36..a33d64d528 100644
--- a/docs/en/backends/openvino.md
+++ b/docs/en/backends/openvino.md
@@ -63,6 +63,28 @@ Notes:
   the RoiAlign operation is replaced with the [ExperimentalDetectronROIFeatureExtractor](https://docs.openvinotoolkit.org/latest/openvino_docs_ops_detection_ExperimentalDetectronROIFeatureExtractor_6.html) operation in the ONNX graph.
 - Models "VFNet" and "Faster R-CNN + DCN" use the custom "DeformableConv2D" operation.
 
+### Deployment config
+
+With the deployment config, you can specify additional options for the Model Optimizer.
+To do this, add the necessary parameters to `backend_config.mo_options` under the fields `args` (for parameters that take values) and `flags` (for parameters that take none).
+
+Example:
+```python
+backend_config = dict(
+    mo_options=dict(
+        args={
+            '--mean_values': [0, 0, 0],
+            '--scale_values': [255, 255, 255],
+            '--data_type': 'FP32',
+        },
+        flags=['--disable_fusing'],
+    )
+)
+```
+
+The full list of Model Optimizer parameters can be found in the [documentation](https://docs.openvino.ai/latest/openvino_docs_MO_DG_prepare_model_convert_model_Converting_Model.html).
+
+
 ### FAQs
 
 - None
diff --git a/docs/en/codebases/mmpose.md b/docs/en/codebases/mmpose.md
index 3851782613..1dd0f9b404 100644
--- a/docs/en/codebases/mmpose.md
+++ b/docs/en/codebases/mmpose.md
@@ -6,7 +6,7 @@
 
 Please refer to [official installation guide](https://mmpose.readthedocs.io/en/latest/install.html) to install the codebase.
 
-## MMEditing models support
+## MMPose models support
 
 | Model | Task | ONNX Runtime | TensorRT | NCNN | PPLNN | OpenVINO | Model config |
 |:----------|:--------------|:------------:|:--------:|:----:|:-----:|:--------:|:-------------------------------------------------------------------------------------------:|
diff --git a/mmdeploy/apis/calibration.py b/mmdeploy/apis/calibration.py
index 69b5ba6dda..1939d502fa 100644
--- a/mmdeploy/apis/calibration.py
+++ b/mmdeploy/apis/calibration.py
@@ -21,6 +21,20 @@ def create_calib_table(calib_file: str,
                        **kwargs) -> None:
     """Create calibration table.
 
+    Examples:
+        >>> from mmdeploy.apis import create_calib_table
+        >>> from mmdeploy.utils import get_calib_filename, load_config
+        >>> deploy_cfg = 'configs/mmdet/detection/' \
+            'detection_tensorrt-int8_dynamic-320x320-1344x1344.py'
+        >>> deploy_cfg = load_config(deploy_cfg)[0]
+        >>> calib_file = get_calib_filename(deploy_cfg)
+        >>> model_cfg = 'mmdetection/configs/fcos/' \
+            'fcos_r50_caffe_fpn_gn-head_1x_coco.py'
+        >>> model_checkpoint = 'checkpoints/' \
+            'fcos_r50_caffe_fpn_gn-head_1x_coco-821213aa.pth'
+        >>> create_calib_table(calib_file, deploy_cfg, \
+            model_cfg, model_checkpoint, device='cuda:0')
+
     Args:
         calib_file (str): Input calibration file.
         deploy_cfg (str | mmcv.Config): Deployment config.
diff --git a/mmdeploy/apis/extract_model.py b/mmdeploy/apis/extract_model.py
index 57da573963..a01afd59b9 100644
--- a/mmdeploy/apis/extract_model.py
+++ b/mmdeploy/apis/extract_model.py
@@ -23,6 +23,30 @@ def extract_model(model: Union[str, onnx.ModelProto],
     The sub-model is defined by the names of the input and output tensors
     exactly.
 
+    Examples:
+        >>> from mmdeploy.apis import extract_model
+        >>> model = 'work_dir/fastrcnn.onnx'
+        >>> start = 'detector:input'
+        >>> end = ['extract_feat:output', 'multiclass_nms[0]:input']
+        >>> dynamic_axes = {
+            'input': {
+                0: 'batch',
+                2: 'height',
+                3: 'width'
+            },
+            'scores': {
+                0: 'batch',
+                1: 'num_boxes',
+            },
+            'boxes': {
+                0: 'batch',
+                1: 'num_boxes',
+            }
+        }
+        >>> save_file = 'partition_model.onnx'
+        >>> extract_model(model, start, end, dynamic_axes=dynamic_axes, \
+            save_file=save_file)
+
     Args:
         model (str | onnx.ModelProto): Input ONNX model to be extracted.
         start (str | Sequence[str]): Start marker(s) to extract.
diff --git a/mmdeploy/apis/inference.py b/mmdeploy/apis/inference.py
index 36a274dccf..47bd204322 100644
--- a/mmdeploy/apis/inference.py
+++ b/mmdeploy/apis/inference.py
@@ -14,6 +14,18 @@ def inference_model(model_cfg: Union[str, mmcv.Config],
                     device: str) -> Any:
     """Run inference with PyTorch or backend model and show results.
 
+    Examples:
+        >>> from mmdeploy.apis import inference_model
+        >>> model_cfg = 'mmdetection/configs/fcos/' \
+            'fcos_r50_caffe_fpn_gn-head_1x_coco.py'
+        >>> deploy_cfg = 'configs/mmdet/detection/' \
+            'detection_onnxruntime_dynamic.py'
+        >>> backend_files = ['work_dir/fcos.onnx']
+        >>> img = 'demo.jpg'
+        >>> device = 'cpu'
+        >>> model_output = inference_model(model_cfg, deploy_cfg, \
+            backend_files, img, device)
+
     Args:
         model_cfg (str | mmcv.Config): Model config file or Config object.
         deploy_cfg (str | mmcv.Config): Deployment config file or Config
diff --git a/mmdeploy/apis/openvino/__init__.py b/mmdeploy/apis/openvino/__init__.py
index f7fbe9a370..97f6ade95d 100644
--- a/mmdeploy/apis/openvino/__init__.py
+++ b/mmdeploy/apis/openvino/__init__.py
@@ -6,7 +6,8 @@
 if is_available():
     from mmdeploy.backend.openvino.onnx2openvino import (get_output_model_file,
                                                           onnx2openvino)
-    from .utils import get_input_info_from_cfg
+    from .utils import get_input_info_from_cfg, get_mo_options_from_cfg
     __all__ += [
-        'onnx2openvino', 'get_output_model_file', 'get_input_info_from_cfg'
+        'onnx2openvino', 'get_output_model_file', 'get_input_info_from_cfg',
+        'get_mo_options_from_cfg'
     ]
diff --git a/mmdeploy/apis/openvino/utils.py b/mmdeploy/apis/openvino/utils.py
index 79710eff21..72317595fd 100644
--- a/mmdeploy/apis/openvino/utils.py
+++ b/mmdeploy/apis/openvino/utils.py
@@ -3,8 +3,9 @@
 
 import mmcv
 
+from mmdeploy.backend.openvino import ModelOptimizerOptions
 from mmdeploy.utils import get_model_inputs
-from mmdeploy.utils.config_utils import get_ir_config
+from mmdeploy.utils.config_utils import get_backend_config, get_ir_config
 
 
 def update_input_names(input_info: Dict[str, List],
@@ -50,3 +51,19 @@ def get_input_info_from_cfg(deploy_cfg: mmcv.Config) -> Dict[str, List]:
     input_info = dict(zip(input_names, input_info))
     input_info = update_input_names(input_info, input_names)
     return input_info
+
+
+def get_mo_options_from_cfg(deploy_cfg: mmcv.Config) -> ModelOptimizerOptions:
+    """Get additional parameters for the Model Optimizer from the deploy
+    config.
+
+    Args:
+        deploy_cfg (mmcv.Config): Deployment config.
+
+    Returns:
+        ModelOptimizerOptions: An object that holds the additional
+            Model Optimizer arguments from `backend_config.mo_options`.
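+
+    Examples:
+        >>> # assumes a deploy config whose backend_config has
+        >>> # type 'openvino'; the path below is just an illustration
+        >>> from mmdeploy.apis.openvino import get_mo_options_from_cfg
+        >>> from mmdeploy.utils import load_config
+        >>> deploy_cfg = 'configs/mmdet/detection/' \
+            'detection_openvino_static-300x300.py'
+        >>> deploy_cfg = load_config(deploy_cfg)[0]
+        >>> mo_options = get_mo_options_from_cfg(deploy_cfg)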
+    """
+    backend_config = get_backend_config(deploy_cfg)
+    mo_options = backend_config.get('mo_options', None)
+    mo_options = ModelOptimizerOptions(mo_options)
+    return mo_options
diff --git a/mmdeploy/apis/pytorch2onnx.py b/mmdeploy/apis/pytorch2onnx.py
index f627c9a346..94f3047b79 100644
--- a/mmdeploy/apis/pytorch2onnx.py
+++ b/mmdeploy/apis/pytorch2onnx.py
@@ -62,6 +62,21 @@ def torch2onnx(img: Any,
                device: str = 'cuda:0'):
     """Convert PyTorch model to ONNX model.
 
+    Examples:
+        >>> from mmdeploy.apis import torch2onnx
+        >>> img = 'demo.jpg'
+        >>> work_dir = 'work_dir'
+        >>> save_file = 'fcos.onnx'
+        >>> deploy_cfg = 'configs/mmdet/detection/' \
+            'detection_onnxruntime_dynamic.py'
+        >>> model_cfg = 'mmdetection/configs/fcos/' \
+            'fcos_r50_caffe_fpn_gn-head_1x_coco.py'
+        >>> model_checkpoint = 'checkpoints/' \
+            'fcos_r50_caffe_fpn_gn-head_1x_coco-821213aa.pth'
+        >>> device = 'cpu'
+        >>> torch2onnx(img, work_dir, save_file, deploy_cfg, \
+            model_cfg, model_checkpoint, device)
+
     Args:
         img (str | np.ndarray | torch.Tensor): Input image used to assist
             converting model.
diff --git a/mmdeploy/apis/visualize.py b/mmdeploy/apis/visualize.py
index e33c1892a8..ade0a21fe8 100644
--- a/mmdeploy/apis/visualize.py
+++ b/mmdeploy/apis/visualize.py
@@ -19,6 +19,18 @@ def visualize_model(model_cfg: Union[str, mmcv.Config],
                     show_result: bool = False):
     """Run inference with PyTorch or backend model and show results.
 
+    Examples:
+        >>> from mmdeploy.apis import visualize_model
+        >>> model_cfg = 'mmdetection/configs/fcos/' \
+            'fcos_r50_caffe_fpn_gn-head_1x_coco.py'
+        >>> deploy_cfg = 'configs/mmdet/detection/' \
+            'detection_onnxruntime_dynamic.py'
+        >>> model = 'work_dir/fcos.onnx'
+        >>> img = 'demo.jpg'
+        >>> device = 'cpu'
+        >>> visualize_model(model_cfg, deploy_cfg, model, \
+            img, device, show_result=True)
+
     Args:
         model_cfg (str | mmcv.Config): Model config file or Config object.
         deploy_cfg (str | mmcv.Config): Deployment config file or Config
diff --git a/mmdeploy/backend/ncnn/onnx2ncnn.py b/mmdeploy/backend/ncnn/onnx2ncnn.py
index 386a93a97e..f4bd30d800 100644
--- a/mmdeploy/backend/ncnn/onnx2ncnn.py
+++ b/mmdeploy/backend/ncnn/onnx2ncnn.py
@@ -33,6 +33,13 @@ def onnx2ncnn(onnx_path: str, save_param: str, save_bin: str):
     an executable program to convert the `.onnx` file to a `.param` file and
     a `.bin` file. The output files will be saved to work_dir.
 
+    Example:
+        >>> from mmdeploy.backend.ncnn.onnx2ncnn import onnx2ncnn
+        >>> onnx_path = 'work_dir/end2end.onnx'
+        >>> save_param = 'work_dir/end2end.param'
+        >>> save_bin = 'work_dir/end2end.bin'
+        >>> onnx2ncnn(onnx_path, save_param, save_bin)
+
     Args:
         onnx_path (str): The path of the onnx model.
         save_param (str): The path to save the output `.param` file.
diff --git a/mmdeploy/backend/openvino/__init__.py b/mmdeploy/backend/openvino/__init__.py
index cb084b5589..7314e48df0 100644
--- a/mmdeploy/backend/openvino/__init__.py
+++ b/mmdeploy/backend/openvino/__init__.py
@@ -13,5 +13,8 @@ def is_available() -> bool:
 
 if is_available():
     from .onnx2openvino import get_output_model_file
+    from .utils import ModelOptimizerOptions
     from .wrapper import OpenVINOWrapper
-    __all__ = ['OpenVINOWrapper', 'get_output_model_file']
+    __all__ = [
+        'OpenVINOWrapper', 'get_output_model_file', 'ModelOptimizerOptions'
+    ]
diff --git a/mmdeploy/backend/openvino/onnx2openvino.py b/mmdeploy/backend/openvino/onnx2openvino.py
index eb592cc1c0..7252efabbd 100644
--- a/mmdeploy/backend/openvino/onnx2openvino.py
+++ b/mmdeploy/backend/openvino/onnx2openvino.py
@@ -2,12 +2,13 @@
 import os.path as osp
 import subprocess
 from subprocess import PIPE, CalledProcessError, run
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union
 
 import mmcv
 import torch
 
 from mmdeploy.utils import get_root_logger
+from .utils import ModelOptimizerOptions
 
 
 def get_mo_command() -> str:
@@ -55,17 +56,29 @@ def get_output_model_file(onnx_path: str, work_dir: str) -> str:
 
 
 def onnx2openvino(input_info: Dict[str, Union[List[int], torch.Size]],
-                  output_names: List[str], onnx_path: str, work_dir: str):
+                  output_names: List[str],
+                  onnx_path: str,
+                  work_dir: str,
+                  mo_options: Optional[ModelOptimizerOptions] = None):
     """Convert ONNX to OpenVINO.
 
+    Examples:
+        >>> from mmdeploy.backend.openvino.onnx2openvino import onnx2openvino
+        >>> input_info = {'input': [1, 3, 800, 1344]}
+        >>> output_names = ['dets', 'labels']
+        >>> onnx_path = 'work_dir/end2end.onnx'
+        >>> work_dir = 'work_dir'
+        >>> onnx2openvino(input_info, output_names, onnx_path, work_dir)
+
     Args:
         input_info (Dict[str, Union[List[int], torch.Size]]):
             The shape of each input.
         output_names (List[str]): Output names. Example: ['dets', 'labels'].
         onnx_path (str): The path to the onnx model.
         work_dir (str): The path to the directory for saving the results.
+        mo_options (None | ModelOptimizerOptions): Additional arguments
+            for the Model Optimizer.
     """
-
     input_names = ','.join(input_info.keys())
     input_shapes = ','.join(str(list(elem)) for elem in input_info.values())
     output = ','.join(output_names)
@@ -80,8 +93,10 @@ def onnx2openvino(input_info: Dict[str, Union[List[int], torch.Size]],
         f'--output_dir="{work_dir}" ' \
         f'--output="{output}" ' \
         f'--input="{input_names}" ' \
-        f'--input_shape="{input_shapes}" ' \
-        f'--disable_fusing '
+        f'--input_shape="{input_shapes}" '
+    if mo_options is not None:
+        mo_args += mo_options.get_options()
+
     command = f'{mo_command} {mo_args}'
 
     logger = get_root_logger()
diff --git a/mmdeploy/backend/openvino/utils.py b/mmdeploy/backend/openvino/utils.py
new file mode 100644
index 0000000000..7aa9dc3b37
--- /dev/null
+++ b/mmdeploy/backend/openvino/utils.py
@@ -0,0 +1,45 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Union
+
+
+class ModelOptimizerOptions:
+    """A class that stores additional arguments for the Model Optimizer,
+    passed in through the deployment config.
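+
+    The options are read from ``backend_config.mo_options`` in the deployment
+    config and have the form
+    ``dict(args={'--some_arg': value, ...}, flags=['--some_flag', ...])``.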
+
+    Example:
+        >>> deploy_cfg = load_config(deploy_cfg_path)[0]
+        >>> mo_options = deploy_cfg['backend_config'].get('mo_options', None)
+        >>> mo_options = ModelOptimizerOptions(mo_options)
+        >>> mo_args = mo_options.get_options()
+    """
+
+    def __init__(self,
+                 mo_options: Optional[Dict[str, Union[Dict, List]]] = None):
+        self.args = ''
+        self.flags = ''
+        if mo_options is not None:
+            self.args = self.__parse_args(mo_options)
+            self.flags = self.__parse_flags(mo_options)
+
+    def __parse_args(self, mo_options: Dict[str, Union[Dict, List]]) -> str:
+        """Parses a dictionary with arguments into a string."""
+        mo_args_str = ''
+        if 'args' in mo_options:
+            for key, value in mo_options['args'].items():
+                value_str = f'"{value}"' if isinstance(value, list) else value
+                mo_args_str += f'{key}={value_str} '
+        return mo_args_str
+
+    def __parse_flags(self, mo_options: Dict[str, Union[Dict, List]]) -> str:
+        """Parses a list with flags into a string."""
+        mo_flags_str = ''
+        if 'flags' in mo_options:
+            mo_flags_str += ' '.join(mo_options['flags'])
+        return mo_flags_str
+
+    def get_options(self) -> str:
+        """Returns a string with additional arguments for the Model Optimizer.
+
+        If there are no additional arguments, it will return an empty string.
+        """
+        return self.args + self.flags
diff --git a/mmdeploy/backend/tensorrt/onnx2tensorrt.py b/mmdeploy/backend/tensorrt/onnx2tensorrt.py
index b0f93512dc..f0e316e468 100644
--- a/mmdeploy/backend/tensorrt/onnx2tensorrt.py
+++ b/mmdeploy/backend/tensorrt/onnx2tensorrt.py
@@ -21,6 +21,17 @@ def onnx2tensorrt(work_dir: str,
                   **kwargs):
     """Convert ONNX to TensorRT.
 
+    Examples:
+        >>> from mmdeploy.backend.tensorrt.onnx2tensorrt import onnx2tensorrt
+        >>> work_dir = 'work_dir'
+        >>> save_file = 'end2end.engine'
+        >>> model_id = 0
+        >>> deploy_cfg = 'configs/mmdet/detection/' \
+            'detection_tensorrt_dynamic-320x320-1344x1344.py'
+        >>> onnx_model = 'work_dir/end2end.onnx'
+        >>> onnx2tensorrt(work_dir, save_file, model_id, deploy_cfg, \
+            onnx_model, 'cuda:0')
+
     Args:
         work_dir (str): A working directory.
         save_file (str): The base name of the file to save TensorRT engine.
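For reference, a minimal sketch (not part of the patch) of how the `ModelOptimizerOptions` class added above renders the config dict into the Model Optimizer argument string; the option values are illustrative, and the import only works when OpenVINO is installed:

```python
from mmdeploy.backend.openvino import ModelOptimizerOptions

# illustrative options; any Model Optimizer arguments/flags can be used
options = ModelOptimizerOptions(
    dict(
        args={'--mean_values': [0, 0, 0], '--data_type': 'FP32'},
        flags=['--disable_fusing']))

# list values are str()-ed and wrapped in quotes, flags appended verbatim:
# --mean_values="[0, 0, 0]" --data_type=FP32 --disable_fusing
print(options.get_options())
```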
diff --git a/mmdeploy/utils/test.py b/mmdeploy/utils/test.py
index c912f14821..5234d3e495 100644
--- a/mmdeploy/utils/test.py
+++ b/mmdeploy/utils/test.py
@@ -487,6 +487,7 @@ def get_backend_outputs(ir_file_path: str,
         import mmdeploy.apis.openvino as openvino_apis
         if not openvino_apis.is_available():
             return None
+        from mmdeploy.apis.openvino import get_mo_options_from_cfg
         openvino_work_dir = tempfile.TemporaryDirectory().name
         openvino_file_path = openvino_apis.get_output_model_file(
             ir_file_path, openvino_work_dir)
@@ -494,8 +495,9 @@ def get_backend_outputs(ir_file_path: str,
             name: value.shape
             for name, value in flatten_model_inputs.items()
         }
+        mo_options = get_mo_options_from_cfg(deploy_cfg)
         openvino_apis.onnx2openvino(input_info, output_names, ir_file_path,
-                                    openvino_work_dir)
+                                    openvino_work_dir, mo_options)
         backend_files = [openvino_file_path]
         backend_feats = flatten_model_inputs
         device = 'cpu'
diff --git a/tests/test_apis/test_onnx2openvino.py b/tests/test_apis/test_onnx2openvino.py
index 43fa623cd2..885d00b312 100644
--- a/tests/test_apis/test_onnx2openvino.py
+++ b/tests/test_apis/test_onnx2openvino.py
@@ -60,9 +60,28 @@ def get_outputs(pytorch_model, openvino_model_path, input, input_name,
     return output_pytorch, openvino_output
 
 
+def get_base_deploy_cfg():
+    deploy_cfg = mmcv.Config(dict(backend_config=dict(type='openvino')))
+    return deploy_cfg
+
+
+def get_deploy_cfg_with_mo_args():
+    deploy_cfg = mmcv.Config(
+        dict(
+            backend_config=dict(
+                type='openvino',
+                mo_options=dict(
+                    args={'--data_type': 'FP32'},
+                    flags=['--disable_fusing']))))
+    return deploy_cfg
+
+
+@pytest.mark.parametrize('get_deploy_cfg',
+                         [get_base_deploy_cfg, get_deploy_cfg_with_mo_args])
 @backend_checker(Backend.OPENVINO)
-def test_onnx2openvino():
-    from mmdeploy.apis.openvino import get_output_model_file, onnx2openvino
+def test_onnx2openvino(get_deploy_cfg):
+    from mmdeploy.apis.openvino import (get_mo_options_from_cfg,
+                                        get_output_model_file, onnx2openvino)
     pytorch_model = TestModel().eval()
     export_img = torch.rand([1, 3, 8, 8])
     onnx_file = tempfile.NamedTemporaryFile(suffix='.onnx').name
@@ -74,7 +93,10 @@ def test_onnx2openvino():
     input_info = {input_name: export_img.shape}
     output_names = [output_name]
     openvino_dir = tempfile.TemporaryDirectory().name
-    onnx2openvino(input_info, output_names, onnx_file, openvino_dir)
+    deploy_cfg = get_deploy_cfg()
+    mo_options = get_mo_options_from_cfg(deploy_cfg)
+    onnx2openvino(input_info, output_names, onnx_file, openvino_dir,
+                  mo_options)
     openvino_model_path = get_output_model_file(onnx_file, openvino_dir)
     assert osp.exists(openvino_model_path), \
         'The file (.xml) for OpenVINO IR has not been created.'
diff --git a/tests/test_codebase/test_mmdet/test_mmdet_models.py b/tests/test_codebase/test_mmdet/test_mmdet_models.py
index f299e18641..706cc05be6 100644
--- a/tests/test_codebase/test_mmdet/test_mmdet_models.py
+++ b/tests/test_codebase/test_mmdet/test_mmdet_models.py
@@ -126,6 +126,7 @@ def get_l2norm_forward_model():
     """L2Norm Neck Config."""
     from mmdet.models.necks.ssd_neck import L2Norm
     model = L2Norm(16)
+    torch.nn.init.uniform_(model.weight)
     model.requires_grad_(False)
     return model
 
diff --git a/tools/deploy.py b/tools/deploy.py
index ee0dd7fb43..63100c757e 100644
--- a/tools/deploy.py
+++ b/tools/deploy.py
@@ -221,6 +221,7 @@ def main():
             'OpenVINO is not available, please install OpenVINO first.'
         from mmdeploy.apis.openvino import (get_input_info_from_cfg,
+                                            get_mo_options_from_cfg,
                                             get_output_model_file,
                                             onnx2openvino)
         openvino_files = []
@@ -228,10 +229,12 @@ def main():
             model_xml_path = get_output_model_file(onnx_path, args.work_dir)
             input_info = get_input_info_from_cfg(deploy_cfg)
             output_names = get_ir_config(deploy_cfg).output_names
+            mo_options = get_mo_options_from_cfg(deploy_cfg)
             create_process(
                 f'onnx2openvino with {onnx_path}',
                 target=onnx2openvino,
-                args=(input_info, output_names, onnx_path, args.work_dir),
+                args=(input_info, output_names, onnx_path, args.work_dir,
+                      mo_options),
                 kwargs=dict(),
                 ret_value=ret_value)
             openvino_files.append(model_xml_path)
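Taken together, the new pieces compose as in the following minimal sketch (not part of the patch); the ONNX path, input shape, and output names are hypothetical, and the OpenVINO Model Optimizer must be installed:

```python
import mmcv

from mmdeploy.apis.openvino import get_mo_options_from_cfg, onnx2openvino

# a minimal deploy config, mirroring the test fixture above
deploy_cfg = mmcv.Config(
    dict(
        backend_config=dict(
            type='openvino',
            mo_options=dict(
                args={'--data_type': 'FP32'},
                flags=['--disable_fusing']))))

# mo_options is picked up from backend_config.mo_options
mo_options = get_mo_options_from_cfg(deploy_cfg)

# hypothetical model with one input and two outputs
input_info = {'input': [1, 3, 800, 1344]}
output_names = ['dets', 'labels']
onnx2openvino(input_info, output_names, 'work_dir/end2end.onnx', 'work_dir',
              mo_options)
```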