[Enhancement] Additional arguments support for OpenVINO Model Optimizer #178

Merged 17 commits on Mar 15, 2022.
22 changes: 22 additions & 0 deletions docs/en/backends/openvino.md
@@ -63,6 +63,28 @@ Notes:
the RoiAlign operation is replaced with the [ExperimentalDetectronROIFeatureExtractor](https://docs.openvinotoolkit.org/latest/openvino_docs_ops_detection_ExperimentalDetectronROIFeatureExtractor_6.html) operation in the ONNX graph.
- Models "VFNet" and "Faster R-CNN + DCN" use the custom "DeformableConv2D" operation.

### Deployment config

Through the deployment config you can pass additional options to the Model Optimizer.
To do so, add the required parameters to `backend_config.mo_options`: put parameters that take values in the `args` field and value-less switches in the `flags` field.

Example:
```python
backend_config = dict(
    mo_options=dict(
        args=dict({
            '--mean_values': [0, 0, 0],
            '--scale_values': [255, 255, 255],
            '--data_type': 'FP32',
        }),
        flags=['--disable_fusing'],
    )
)
```

Information about the possible parameters for the Model Optimizer can be found in the [documentation](https://docs.openvino.ai/latest/openvino_docs_MO_DG_prepare_model_convert_model_Converting_Model.html).
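
Under the hood, these options are serialized by the `ModelOptimizerOptions` helper introduced in this PR and appended to the `mo` command line that MMDeploy builds. As a rough sketch, the example config above would expand to an argument string like the following (exact quoting may vary):

```python
'--mean_values="[0, 0, 0]" --scale_values="[255, 255, 255]" --data_type=FP32 --disable_fusing'
```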


### FAQs

- None
2 changes: 1 addition & 1 deletion docs/en/codebases/mmpose.md
@@ -6,7 +6,7 @@

Please refer to [official installation guide](https://mmpose.readthedocs.io/en/latest/install.html) to install the codebase.

-## MMEditing models support
+## MMPose models support

| Model | Task | ONNX Runtime | TensorRT | NCNN | PPLNN | OpenVINO | Model config |
|:----------|:--------------|:------------:|:--------:|:----:|:-----:|:--------:|:-------------------------------------------------------------------------------------------:|
14 changes: 14 additions & 0 deletions mmdeploy/apis/calibration.py
@@ -21,6 +21,20 @@ def create_calib_table(calib_file: str,
                       **kwargs) -> None:
    """Create calibration table.

    Examples:
        >>> from mmdeploy.apis import create_calib_table
        >>> from mmdeploy.utils import get_calib_filename, load_config
        >>> deploy_cfg = 'configs/mmdet/detection/' \
            'detection_tensorrt-int8_dynamic-320x320-1344x1344.py'
        >>> deploy_cfg = load_config(deploy_cfg)[0]
        >>> calib_file = get_calib_filename(deploy_cfg)
        >>> model_cfg = 'mmdetection/configs/fcos/' \
            'fcos_r50_caffe_fpn_gn-head_1x_coco.py'
        >>> model_checkpoint = 'checkpoints/' \
            'fcos_r50_caffe_fpn_gn-head_1x_coco-821213aa.pth'
        >>> create_calib_table(calib_file, deploy_cfg, \
            model_cfg, model_checkpoint, device='cuda:0')

    Args:
        calib_file (str): Input calibration file.
        deploy_cfg (str | mmcv.Config): Deployment config.
24 changes: 24 additions & 0 deletions mmdeploy/apis/extract_model.py
@@ -23,6 +23,30 @@ def extract_model(model: Union[str, onnx.ModelProto],
    The sub-model is defined exactly by the names of its input and output
    tensors.

    Examples:
        >>> from mmdeploy.apis import extract_model
        >>> model = 'work_dir/fastrcnn.onnx'
        >>> start = 'detector:input'
        >>> end = ['extract_feat:output', 'multiclass_nms[0]:input']
        >>> dynamic_axes = {
                'input': {
                    0: 'batch',
                    2: 'height',
                    3: 'width'
                },
                'scores': {
                    0: 'batch',
                    1: 'num_boxes',
                },
                'boxes': {
                    0: 'batch',
                    1: 'num_boxes',
                }
            }
        >>> save_file = 'partition_model.onnx'
        >>> extract_model(model, start, end, dynamic_axes=dynamic_axes, \
            save_file=save_file)

    Args:
        model (str | onnx.ModelProto): Input ONNX model to be extracted.
        start (str | Sequence[str]): Start marker(s) to extract.
12 changes: 12 additions & 0 deletions mmdeploy/apis/inference.py
@@ -14,6 +14,18 @@ def inference_model(model_cfg: Union[str, mmcv.Config],
                    device: str) -> Any:
    """Run inference with PyTorch or backend model and show results.

    Examples:
        >>> from mmdeploy.apis import inference_model
        >>> model_cfg = 'mmdetection/configs/fcos/' \
            'fcos_r50_caffe_fpn_gn-head_1x_coco.py'
        >>> deploy_cfg = 'configs/mmdet/detection/' \
            'detection_onnxruntime_dynamic.py'
        >>> backend_files = ['work_dir/fcos.onnx']
        >>> img = 'demo.jpg'
        >>> device = 'cpu'
        >>> model_output = inference_model(model_cfg, deploy_cfg, \
            backend_files, img, device)

    Args:
        model_cfg (str | mmcv.Config): Model config file or Config object.
        deploy_cfg (str | mmcv.Config): Deployment config file or Config
5 changes: 3 additions & 2 deletions mmdeploy/apis/openvino/__init__.py
@@ -6,7 +6,8 @@
if is_available():
    from mmdeploy.backend.openvino.onnx2openvino import (get_output_model_file,
                                                          onnx2openvino)
-    from .utils import get_input_info_from_cfg
+    from .utils import get_input_info_from_cfg, get_mo_options_from_cfg
    __all__ += [
-        'onnx2openvino', 'get_output_model_file', 'get_input_info_from_cfg'
+        'onnx2openvino', 'get_output_model_file', 'get_input_info_from_cfg',
+        'get_mo_options_from_cfg'
    ]
19 changes: 18 additions & 1 deletion mmdeploy/apis/openvino/utils.py
@@ -3,8 +3,9 @@

import mmcv

+from mmdeploy.backend.openvino import ModelOptimizerOptions
from mmdeploy.utils import get_model_inputs
-from mmdeploy.utils.config_utils import get_ir_config
+from mmdeploy.utils.config_utils import get_backend_config, get_ir_config


def update_input_names(input_info: Dict[str, List],
@@ -50,3 +51,19 @@ def get_input_info_from_cfg(deploy_cfg: mmcv.Config) -> Dict[str, List]:
    input_info = dict(zip(input_names, input_info))
    input_info = update_input_names(input_info, input_names)
    return input_info


def get_mo_options_from_cfg(deploy_cfg: mmcv.Config) -> ModelOptimizerOptions:
    """Get additional parameters for the Model Optimizer from the deploy
    config.

    Args:
        deploy_cfg (mmcv.Config): Deployment config.

    Returns:
        ModelOptimizerOptions: An object wrapping the additional arguments.
    """
    backend_config = get_backend_config(deploy_cfg)
    mo_options = backend_config.get('mo_options', None)
    mo_options = ModelOptimizerOptions(mo_options)
    return mo_options
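
A minimal usage sketch of the new helper; the config path is illustrative, and any deploy config whose `backend_config` carries `mo_options` behaves the same way:

```python
from mmdeploy.utils import load_config
from mmdeploy.apis.openvino import get_mo_options_from_cfg

# Illustrative path; substitute your own OpenVINO deploy config.
deploy_cfg = load_config(
    'configs/mmdet/detection/detection_openvino_static-800x1344.py')[0]
mo_options = get_mo_options_from_cfg(deploy_cfg)
# Prints an empty string when the config defines no mo_options.
print(mo_options.get_options())
```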
15 changes: 15 additions & 0 deletions mmdeploy/apis/pytorch2onnx.py
@@ -62,6 +62,21 @@ def torch2onnx(img: Any,
               device: str = 'cuda:0'):
    """Convert PyTorch model to ONNX model.

    Examples:
        >>> from mmdeploy.apis import torch2onnx
        >>> img = 'demo.jpg'
        >>> work_dir = 'work_dir'
        >>> save_file = 'fcos.onnx'
        >>> deploy_cfg = 'configs/mmdet/detection/' \
            'detection_onnxruntime_dynamic.py'
        >>> model_cfg = 'mmdetection/configs/fcos/' \
            'fcos_r50_caffe_fpn_gn-head_1x_coco.py'
        >>> model_checkpoint = 'checkpoints/' \
            'fcos_r50_caffe_fpn_gn-head_1x_coco-821213aa.pth'
        >>> device = 'cpu'
        >>> torch2onnx(img, work_dir, save_file, deploy_cfg, \
            model_cfg, model_checkpoint, device)

    Args:
        img (str | np.ndarray | torch.Tensor): Input image used to assist
            converting model.
12 changes: 12 additions & 0 deletions mmdeploy/apis/visualize.py
@@ -19,6 +19,18 @@ def visualize_model(model_cfg: Union[str, mmcv.Config],
                    show_result: bool = False):
    """Run inference with PyTorch or backend model and show results.

    Examples:
        >>> from mmdeploy.apis import visualize_model
        >>> model_cfg = 'mmdetection/configs/fcos/' \
            'fcos_r50_caffe_fpn_gn-head_1x_coco.py'
        >>> deploy_cfg = 'configs/mmdet/detection/' \
            'detection_onnxruntime_dynamic.py'
        >>> model = 'work_dir/fcos.onnx'
        >>> img = 'demo.jpg'
        >>> device = 'cpu'
        >>> visualize_model(model_cfg, deploy_cfg, model, \
            img, device, show_result=True)

    Args:
        model_cfg (str | mmcv.Config): Model config file or Config object.
        deploy_cfg (str | mmcv.Config): Deployment config file or Config
7 changes: 7 additions & 0 deletions mmdeploy/backend/ncnn/onnx2ncnn.py
@@ -33,6 +33,13 @@ def onnx2ncnn(onnx_path: str, save_param: str, save_bin: str):
    an executable program to convert the `.onnx` file to a `.param` file and
    a `.bin` file. The output files will be saved to work_dir.

    Example:
        >>> from mmdeploy.backend.ncnn.onnx2ncnn import onnx2ncnn
        >>> onnx_path = 'work_dir/end2end.onnx'
        >>> save_param = 'work_dir/end2end.param'
        >>> save_bin = 'work_dir/end2end.bin'
        >>> onnx2ncnn(onnx_path, save_param, save_bin)

    Args:
        onnx_path (str): The path of the onnx model.
        save_param (str): The path to save the output `.param` file.
5 changes: 4 additions & 1 deletion mmdeploy/backend/openvino/__init__.py
@@ -13,5 +13,8 @@ def is_available() -> bool:

if is_available():
    from .onnx2openvino import get_output_model_file
+    from .utils import ModelOptimizerOptions
    from .wrapper import OpenVINOWrapper
-    __all__ = ['OpenVINOWrapper', 'get_output_model_file']
+    __all__ = [
+        'OpenVINOWrapper', 'get_output_model_file', 'ModelOptimizerOptions'
+    ]
25 changes: 20 additions & 5 deletions mmdeploy/backend/openvino/onnx2openvino.py
@@ -2,12 +2,13 @@
import os.path as osp
import subprocess
from subprocess import PIPE, CalledProcessError, run
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union

import mmcv
import torch

from mmdeploy.utils import get_root_logger
from .utils import ModelOptimizerOptions


def get_mo_command() -> str:
@@ -55,17 +56,29 @@ def get_output_model_file(onnx_path: str, work_dir: str) -> str:


def onnx2openvino(input_info: Dict[str, Union[List[int], torch.Size]],
-                  output_names: List[str], onnx_path: str, work_dir: str):
+                  output_names: List[str],
+                  onnx_path: str,
+                  work_dir: str,
+                  mo_options: Optional[ModelOptimizerOptions] = None):
"""Convert ONNX to OpenVINO.

Examples:
>>> from mmdeploy.backend.openvino.onnx2openvino import onnx2openvino
>>> input_info = {'input': [1,3,800,1344]}
>>> output_names = ['dets', 'labels']
>>> onnx_path = 'work_dir/end2end.onnx'
>>> work_dir = 'work_dir'
>>> onnx2openvino(input_info, output_names, onnx_path, work_dir)

Args:
input_info (Dict[str, Union[List[int], torch.Size]]):
The shape of each input.
output_names (List[str]): Output names. Example: ['dets', 'labels'].
onnx_path (str): The path to the onnx model.
work_dir (str): The path to the directory for saving the results.
mo_options (None | ModelOptimizerOptions): The class with
additional arguments for the Model Optimizer.
"""

input_names = ','.join(input_info.keys())
input_shapes = ','.join(str(list(elem)) for elem in input_info.values())
output = ','.join(output_names)
@@ -80,8 +93,10 @@ def onnx2openvino(input_info: Dict[str, Union[List[int], torch.Size]],
              f'--output_dir="{work_dir}" ' \
              f'--output="{output}" ' \
              f'--input="{input_names}" ' \
-              f'--input_shape="{input_shapes}" ' \
-              f'--disable_fusing '
+              f'--input_shape="{input_shapes}" '
+    if mo_options is not None:
+        mo_args += mo_options.get_options()

    command = f'{mo_command} {mo_args}'

    logger = get_root_logger()
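
For reference, a minimal sketch of driving the converter with extra Model Optimizer options; the paths and input shape are illustrative:

```python
from mmdeploy.backend.openvino.onnx2openvino import onnx2openvino
from mmdeploy.backend.openvino.utils import ModelOptimizerOptions

# The dict mirrors the `backend_config.mo_options` format from the docs.
mo_options = ModelOptimizerOptions(
    dict(args={'--data_type': 'FP32'}, flags=['--disable_fusing']))
onnx2openvino(
    input_info={'input': [1, 3, 800, 1344]},  # illustrative shape
    output_names=['dets', 'labels'],
    onnx_path='work_dir/end2end.onnx',
    work_dir='work_dir',
    mo_options=mo_options)
```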
45 changes: 45 additions & 0 deletions mmdeploy/backend/openvino/utils.py
@@ -0,0 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union


class ModelOptimizerOptions:
    """A class to make it easier to support additional arguments for the Model
    Optimizer that can be passed through the deployment configuration.

    Example:
        >>> deploy_cfg = load_config(deploy_cfg_path)
        >>> mo_options = deploy_cfg.get('mo_options', None)
        >>> mo_options = ModelOptimizerOptions(mo_options)
        >>> mo_args = mo_options.get_options()
    """

    def __init__(self,
                 mo_options: Optional[Dict[str, Union[Dict, List]]] = None):
        self.args = ''
        self.flags = ''
        if mo_options is not None:
            self.args = self.__parse_args(mo_options)
            self.flags = self.__parse_flags(mo_options)

    def __parse_args(self, mo_options: Dict[str, Union[Dict, List]]) -> str:
        """Parses a dictionary with arguments into a string."""
        mo_args_str = ''
        if 'args' in mo_options:
            for key, value in mo_options['args'].items():
                value_str = f'"{value}"' if isinstance(value, list) else value
                mo_args_str += f'{key}={value_str} '
        return mo_args_str

    def __parse_flags(self, mo_options: Dict[str, Union[Dict, List]]) -> str:
        """Parses a list with flags into a string."""
        mo_flags_str = ''
        if 'flags' in mo_options:
            mo_flags_str += ' '.join(mo_options['flags'])
        return mo_flags_str

    def get_options(self) -> str:
        """Returns a string with additional arguments for the Model Optimizer.

        If there are no additional arguments, it will return an empty string.
        """
        return self.args + self.flags
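
A quick sanity check of the serialization behavior, as a sketch; note that list values are wrapped in quotes so the shell passes them to `mo` as a single argument:

```python
>>> from mmdeploy.backend.openvino.utils import ModelOptimizerOptions
>>> opts = ModelOptimizerOptions(
...     dict(args={'--mean_values': [0, 0, 0], '--data_type': 'FP32'},
...          flags=['--disable_fusing']))
>>> opts.get_options()
'--mean_values="[0, 0, 0]" --data_type=FP32 --disable_fusing'
```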
11 changes: 11 additions & 0 deletions mmdeploy/backend/tensorrt/onnx2tensorrt.py
@@ -21,6 +21,17 @@ def onnx2tensorrt(work_dir: str,
                  **kwargs):
    """Convert ONNX to TensorRT.

    Examples:
        >>> from mmdeploy.backend.tensorrt.onnx2tensorrt import onnx2tensorrt
        >>> work_dir = 'work_dir'
        >>> save_file = 'end2end.engine'
        >>> model_id = 0
        >>> deploy_cfg = 'configs/mmdet/detection/' \
            'detection_tensorrt_dynamic-320x320-1344x1344.py'
        >>> onnx_model = 'work_dir/end2end.onnx'
        >>> onnx2tensorrt(work_dir, save_file, model_id, deploy_cfg, \
            onnx_model, 'cuda:0')

    Args:
        work_dir (str): A working directory.
        save_file (str): The base name of the file to save TensorRT engine.
4 changes: 3 additions & 1 deletion mmdeploy/utils/test.py
@@ -487,15 +487,17 @@ def get_backend_outputs(ir_file_path: str,
        import mmdeploy.apis.openvino as openvino_apis
        if not openvino_apis.is_available():
            return None
+        from mmdeploy.apis.openvino import get_mo_options_from_cfg
        openvino_work_dir = tempfile.TemporaryDirectory().name
        openvino_file_path = openvino_apis.get_output_model_file(
            ir_file_path, openvino_work_dir)
        input_info = {
            name: value.shape
            for name, value in flatten_model_inputs.items()
        }
+        mo_options = get_mo_options_from_cfg(deploy_cfg)
        openvino_apis.onnx2openvino(input_info, output_names, ir_file_path,
-                                    openvino_work_dir)
+                                    openvino_work_dir, mo_options)
        backend_files = [openvino_file_path]
        backend_feats = flatten_model_inputs
        device = 'cpu'