From afd1751159b7bcd1ad29fafef8e4b3b49e0b9af9 Mon Sep 17 00:00:00 2001 From: maningsheng Date: Thu, 17 Dec 2020 17:45:05 +0800 Subject: [PATCH 01/10] [Enhance]: add onnx simplify --- mmcv/onnx/__init__.py | 3 +- mmcv/onnx/simplify/__init__.py | 3 + mmcv/onnx/simplify/common.py | 43 ++++ mmcv/onnx/simplify/core.py | 448 +++++++++++++++++++++++++++++++++ setup.cfg | 2 +- 5 files changed, 497 insertions(+), 2 deletions(-) create mode 100644 mmcv/onnx/simplify/__init__.py create mode 100644 mmcv/onnx/simplify/common.py create mode 100644 mmcv/onnx/simplify/core.py diff --git a/mmcv/onnx/__init__.py b/mmcv/onnx/__init__.py index 1122c5f866..641999cb06 100644 --- a/mmcv/onnx/__init__.py +++ b/mmcv/onnx/__init__.py @@ -1,3 +1,4 @@ +from .simplify import simplify from .symbolic import register_extra_symbolics -__all__ = ['register_extra_symbolics'] +__all__ = ['register_extra_symbolics', 'simplify'] diff --git a/mmcv/onnx/simplify/__init__.py b/mmcv/onnx/simplify/__init__.py new file mode 100644 index 0000000000..d4498ab087 --- /dev/null +++ b/mmcv/onnx/simplify/__init__.py @@ -0,0 +1,3 @@ +from .core import simplify + +__all__ = ['simplify'] diff --git a/mmcv/onnx/simplify/common.py b/mmcv/onnx/simplify/common.py new file mode 100644 index 0000000000..701e59aab1 --- /dev/null +++ b/mmcv/onnx/simplify/common.py @@ -0,0 +1,43 @@ +import copy +import warnings + +import onnx + + +def add_suffix2name(ori_model, suffix='__', verify=True): + """Simply add a suffix to the name of node, which has a numeric name.""" + # check if has special op, which has subgraph. 
+ special_ops = ('If', 'Loop') + for node in ori_model.graph.node: + if node.op_type in special_ops: + warnings.warn(f'This model has special op: {node.op_type}.') + return ori_model + + model = copy.deepcopy(ori_model) + + def need_update(name): + return name.isnumeric() + + def update_name(nodes): + for node in nodes: + if need_update(node.name): + node.name += suffix + + update_name(model.graph.initializer) + update_name(model.graph.input) + update_name(model.graph.output) + + for i, node in enumerate(ori_model.graph.node): + # process input of node + for j, name in enumerate(node.input): + if need_update(name): + model.graph.node[i].input[j] = name + suffix + + # process output of node + for j, name in enumerate(node.output): + if need_update(name): + model.graph.node[i].output[j] = name + suffix + if verify: + onnx.checker.check_model(model) + + return model diff --git a/mmcv/onnx/simplify/core.py b/mmcv/onnx/simplify/core.py new file mode 100644 index 0000000000..d4c5c36d65 --- /dev/null +++ b/mmcv/onnx/simplify/core.py @@ -0,0 +1,448 @@ +# This file is modified from https://github.com/daquexian/onnx-simplifier +import copy +import os +from collections import OrderedDict +from typing import Dict, List, Optional, Sequence, Union + +import numpy as np # type: ignore +import onnx # type: ignore +import onnx.helper # type: ignore +import onnx.numpy_helper +import onnx.shape_inference # type: ignore +import onnxoptimizer # type: ignore +import onnxruntime as rt # type: ignore + +from .common import add_suffix2name + +TensorShape = List[int] +TensorShapes = Dict[Optional[str], TensorShape] + + +def add_features_to_output(m: onnx.ModelProto, + nodes: List[onnx.NodeProto]) -> None: + """Add features to output in pb, so that ONNX Runtime will output them. 
+ + :param m: the model that will be run in ONNX Runtime + :param nodes: nodes whose outputs will be added into the graph outputs + """ + for node in nodes: + for output in node.output: + m.graph.output.extend([onnx.ValueInfoProto(name=output)]) + + +def get_shape_from_value_info_proto(v: onnx.ValueInfoProto) -> List[int]: + return [dim.dim_value for dim in v.type.tensor_type.shape.dim] + + +def get_value_info_all(m: onnx.ModelProto, + name: str) -> Optional[onnx.ValueInfoProto]: + for v in m.graph.value_info: + if v.name == name: + return v + + for v in m.graph.input: + if v.name == name: + return v + + for v in m.graph.output: + if v.name == name: + return v + + return None + + +def get_shape(m: onnx.ModelProto, name: str) -> TensorShape: + """ + Note: This method relies on onnx shape inference, which is not reliable. + So only use it on input or output tensors + """ + v = get_value_info_all(m, name) + if v is not None: + return get_shape_from_value_info_proto(v) + raise RuntimeError('Cannot get shape of "{}"'.format(name)) + + +def get_elem_type(m: onnx.ModelProto, name: str) -> Optional[int]: + v = get_value_info_all(m, name) + if v is not None: + return v.type.tensor_type.elem_type + return None + + +def get_np_type_from_elem_type(elem_type: int) -> int: + # from https://github.com/onnx/onnx/blob/ + # e5e9a539f550f07ec156812484e8d4f33fb91f88/onnx/onnx.proto#L461 + sizes = (None, np.float32, np.uint8, np.int8, np.uint16, np.int16, + np.int32, np.int64, str, np.bool, np.float16, np.double, + np.uint32, np.uint64, np.complex64, np.complex128, np.float16) + assert len(sizes) == 17 + size = sizes[elem_type] + assert size is not None + return size + + +def get_input_names(model: onnx.ModelProto) -> List[str]: + input_names = list( + set([ipt.name for ipt in model.graph.input]) - + set([x.name for x in model.graph.initializer])) + return input_names + + +def add_initializers_into_inputs(model: onnx.ModelProto) -> onnx.ModelProto: + for x in model.graph.initializer: + 
input_names = [x.name for x in model.graph.input] + if x.name not in input_names: + shape = onnx.TensorShapeProto() + for dim in x.dims: + shape.dim.extend( + [onnx.TensorShapeProto.Dimension(dim_value=dim)]) + model.graph.input.extend([ + onnx.ValueInfoProto( + name=x.name, + type=onnx.TypeProto( + tensor_type=onnx.TypeProto.Tensor( + elem_type=x.data_type, shape=shape))) + ]) + return model + + +def generate_rand_input(model, input_shapes: Optional[TensorShapes] = None): + if input_shapes is None: + input_shapes = {} + input_names = get_input_names(model) + full_input_shapes = {ipt: get_shape(model, ipt) for ipt in input_names} + assert None not in input_shapes + full_input_shapes.update(input_shapes) # type: ignore + for key in full_input_shapes: + if np.prod(full_input_shapes[key]) <= 0: + raise RuntimeError(f'The shape of input "{key}" has dynamic size, \ + please determine the input size manually.') + + inputs = { + ipt: np.array( + np.random.rand(*full_input_shapes[ipt]), + dtype=get_np_type_from_elem_type(get_elem_type(model, ipt))) + for ipt in input_names + } + return inputs + + +def get_constant_nodes(m: onnx.ModelProto) -> List[onnx.NodeProto]: + + const_nodes = [] + const_tensors = [x.name for x in m.graph.initializer] + const_tensors.extend([ + node.output[0] for node in m.graph.node if node.op_type == 'Constant' + ]) + # The output shape of some node types is determined by the input value + # we consider the output of this node doesn't have constant shape, + # so we do not simplify a such node even if the node is Shape op + dynamic_tensors = [] + + def is_dynamic(node): + if node.op_type in ['NonMaxSuppression', 'NonZero', 'Unique' + ] and node.input[0] not in const_tensors: + return True + if node.op_type in [ + 'Reshape', 'Expand', 'Upsample', 'ConstantOfShape' + ] and len(node.input) > 1 and node.input[1] not in const_tensors: + return True + if node.op_type in ['Resize'] and ( + (len(node.input) > 2 and node.input[2] not in const_tensors) or + 
(len(node.input) > 3 + and node.input[3] not in const_tensors)): # noqa: E129 + return True + return False + + for node in m.graph.node: + if any(x in dynamic_tensors for x in node.input): + dynamic_tensors.extend(node.output) + elif node.op_type == 'Shape': + const_nodes.append(node) + const_tensors.extend(node.output) + elif is_dynamic(node): + dynamic_tensors.extend(node.output) + elif all([x in const_tensors for x in node.input]): + const_nodes.append(node) + const_tensors.extend(node.output) + return copy.deepcopy(const_nodes) + + +def forward( + model, + inputs: Dict[str, np.ndarray] = None, + input_shapes: Optional[TensorShapes] = None) -> Dict[str, np.ndarray]: + if input_shapes is None: + input_shapes = {} + sess_options = rt.SessionOptions() + # load custom lib for onnxruntime in mmcv + ort_custom_op_path = None + try: + from mmcv.ops import get_onnxruntime_op_path + ort_custom_op_path = get_onnxruntime_op_path() + except ImportError: + pass + if ort_custom_op_path is not None: + sess_options.register_custom_ops_library(ort_custom_op_path) + sess_options.graph_optimization_level = rt.GraphOptimizationLevel(0) + sess_options.log_severity_level = 3 + sess = rt.InferenceSession( + model.SerializeToString(), + sess_options=sess_options, + providers=['CPUExecutionProvider']) + if inputs is None: + inputs = generate_rand_input(model, input_shapes=input_shapes) + outputs = [x.name for x in sess.get_outputs()] + run_options = rt.RunOptions() + run_options.log_severity_level = 3 + res = OrderedDict( + zip(outputs, sess.run(outputs, inputs, run_options=run_options))) + return res + + +def forward_for_node_outputs( + model: onnx.ModelProto, + nodes: List[onnx.NodeProto], + input_shapes: Optional[TensorShapes] = None, + inputs: Optional[Dict[str, + np.ndarray]] = None) -> Dict[str, np.ndarray]: + if input_shapes is None: + input_shapes = {} + model = copy.deepcopy(model) + add_features_to_output(model, nodes) + res = forward(model, inputs=inputs, 
input_shapes=input_shapes) + return res + + +def insert_elem(repeated_container, index: int, element): + repeated_container.extend([repeated_container[-1]]) + for i in reversed(range(index + 1, len(repeated_container) - 1)): + repeated_container[i].CopyFrom(repeated_container[i - 1]) + repeated_container[index].CopyFrom(element) + + +def eliminate_const_nodes(model: onnx.ModelProto, + const_nodes: List[onnx.NodeProto], + res: Dict[str, np.ndarray]) -> onnx.ModelProto: + """ + :param model: the original onnx model + :param const_nodes: const nodes detected by `get_constant_nodes` + :param res: The dict containing all tensors, got by `forward_all` + :return: the simplified onnx model. Redundant ops are all removed. + """ + for i, node in enumerate(model.graph.node): + if node in const_nodes: + for output in node.output: + new_node = copy.deepcopy(node) + new_node.name = 'node_' + output + new_node.op_type = 'Constant' + new_attr = onnx.helper.make_attribute( + 'value', + onnx.numpy_helper.from_array(res[output], name=output)) + del new_node.input[:] + del new_node.attribute[:] + del new_node.output[:] + new_node.output.extend([output]) + new_node.attribute.extend([new_attr]) + insert_elem(model.graph.node, i + 1, new_node) + del model.graph.node[i] + + return model + + +def optimize(model: onnx.ModelProto, skip_fuse_bn: bool, + skipped_optimizers: Optional[Sequence[str]]) -> onnx.ModelProto: + """ + :param model: The onnx model. + :return: The optimized onnx model. + Before simplifying, use this method to generate value_info, + which is used in `forward_all` + After simplifying, use this method to fold constants generated + in previous step into initializer, and eliminate unused constants. 
+ """ + + # Due to a onnx bug, https://github.com/onnx/onnx/issues/2417, + # we need to add missing initializers into inputs + + onnx.checker.check_model(model) + input_num = len(model.graph.input) + model = add_initializers_into_inputs(model) + onnx.helper.strip_doc_string(model) + onnx.checker.check_model(model) + optimizers_list = [ + 'eliminate_deadend', 'eliminate_nop_dropout', 'eliminate_nop_cast', + 'eliminate_nop_monotone_argmax', 'eliminate_nop_pad', + 'extract_constant_to_initializer', 'eliminate_unused_initializer', + 'eliminate_nop_transpose', 'eliminate_identity', + 'fuse_add_bias_into_conv', 'fuse_consecutive_concats', + 'fuse_consecutive_log_softmax', 'fuse_consecutive_reduce_unsqueeze', + 'fuse_consecutive_squeezes', 'fuse_consecutive_transposes', + 'fuse_matmul_add_bias_into_gemm', 'fuse_pad_into_conv', + 'fuse_transpose_into_gemm' + ] + if not skip_fuse_bn: + optimizers_list.append('fuse_bn_into_conv') + if skipped_optimizers is not None: + for opt in skipped_optimizers: + try: + optimizers_list.remove(opt) + except ValueError: + pass + + model = onnxoptimizer.optimize(model, optimizers_list, fixed_point=True) + if model.ir_version > 3: + del model.graph.input[input_num:] + onnx.checker.check_model(model) + return model + + +def check(model_opt: onnx.ModelProto, + model_ori: onnx.ModelProto, + n_times: int = 5, + input_shapes: Optional[TensorShapes] = None, + inputs: Optional[List[Dict[str, np.ndarray]]] = None) -> bool: + """ + Warning: + Some models (e.g., MobileNet) may fail this check by a small magnitude. + Just ignore if it happens. 
+ :param input_shapes: Shapes of generated random inputs + :param model_opt: The simplified ONNX model + :param model_ori: The original ONNX model + :param n_times: Generate n random inputs + """ + if input_shapes is None: + input_shapes = {} + onnx.checker.check_model(model_opt) + if inputs is not None: + n_times = min(n_times, len(inputs)) + for i in range(n_times): + print(f'Checking {i}/{n_times}...') + if inputs is None: + model_input = generate_rand_input( + model_opt, input_shapes=input_shapes) + else: + model_input = inputs[i] + res_opt = forward(model_opt, inputs=model_input) + res_ori = forward(model_ori, inputs=model_input) + + for name in res_opt.keys(): + if not np.allclose( + res_opt[name], res_ori[name], rtol=1e-4, atol=1e-5): + print( + 'Tensor {} changes after simplifying. The max diff is {}.'. + format(name, + np.max(np.abs(res_opt[name] - res_ori[name])))) + print('Note that the checking is not always correct.') + print('After simplifying:') + print(res_opt[name]) + print('Before simplifying:') + print(res_ori[name]) + print('----------------') + return False + return True + + +def clean_constant_nodes(const_nodes: List[onnx.NodeProto], + res: Dict[str, np.ndarray]): + """It seems not needed since commit 6f2a72, but maybe it still prevents + some unknown bug. 
+ + :param const_nodes: const nodes detected by `get_constant_nodes` + :param res: The dict containing all tensors, got by `forward_all` + :return: The constant nodes which have an output in res + """ + return [node for node in const_nodes if node.output[0] in res] + + +def check_and_update_input_shapes(model: onnx.ModelProto, + input_shapes: TensorShapes) -> TensorShapes: + input_names = get_input_names(model) + if None in input_shapes: + if len(input_names) == 1: + input_shapes[input_names[0]] = input_shapes[None] + del input_shapes[None] + else: + raise RuntimeError('The model has more than 1 inputs!') + for x in input_shapes: + if x not in input_names: + raise RuntimeError(f'The model doesn\'t have input named "{x}"') + return input_shapes + + +def simplify(model: Union[str, onnx.ModelProto], + inputs: Sequence[Dict[str, np.ndarray]] = None, + output_file: str = None, + perform_optimization: bool = True, + skip_fuse_bn: bool = False, + skip_shape_inference=True, + input_shapes: Dict[str, Sequence[int]] = None, + skipped_optimizers: Sequence[str] = None) -> onnx.ModelProto: + """Simplify and optimize an onnx model. + + For models from detection and segmentation, it is strongly suggested to + input multiple input images for verification. + + Arguments: + model (str or onnx.ModelProto), path of model or loaded model object. + inputs (optional, Sequence[Dict[str, np.ndarray]]), inputs of model. + output_file (optional, str): output file to save simplified model. + perform_optimization (optional, bool), whether to perform optimization. + skip_fuse_bn (optional, bool): whether to skip fusing bn layer. + skip_shape_inference (optional, bool): whether to skip shape inference. + input_shapes (optional, Dict[str, Sequence[int]]): + the shapes of model inputs. + skipped_optimizers (optional, Sequence[str]): + the names of optimizer to be skipped. + + Returns: + onnx.ModelProto: simplified and optimized onnx model. 
+ + Example: + >>> import onnx + >>> import numpy as np + >>> + >>> from mmcv.onnx import simplify + >>> + >>> input = np.random.randn(1, 3, 224, 224).astype(np.float32) + >>> input_file = 'sample.onnx' + >>> output_file = 'slim.onnx' + >>> model = simplify(input_file, [input], output_file) + """ + if input_shapes is None: + input_shapes = {} + if isinstance(model, str): + model = onnx.load(model) + # rename op with numeric name for issue + # https://github.com/onnx/onnx/issues/2613 + model = add_suffix2name(model) + onnx.checker.check_model(model) + model_ori = copy.deepcopy(model) + if not skip_shape_inference: + model = onnx.shape_inference.infer_shapes(model) + + input_shapes = check_and_update_input_shapes(model, input_shapes) + + if perform_optimization: + model = optimize(model, skip_fuse_bn, skipped_optimizers) + + const_nodes = get_constant_nodes(model) + feed_inputs = None if inputs is None else inputs[0] + res = forward_for_node_outputs( + model, const_nodes, input_shapes=input_shapes, inputs=feed_inputs) + const_nodes = clean_constant_nodes(const_nodes, res) + model = eliminate_const_nodes(model, const_nodes, res) + onnx.checker.check_model(model) + + if perform_optimization: + model = optimize(model, skip_fuse_bn, skipped_optimizers) + + check_ok = check( + model_ori, model, input_shapes=input_shapes, inputs=inputs) + + assert check_ok, 'Check failed for the simplified model!' 
+ if output_file is not None: + save_dir, _ = os.path.split(output_file) + if save_dir: + os.makedirs(save_dir, exist_ok=True) + onnx.save(model, output_file) + return model diff --git a/setup.cfg b/setup.cfg index 3406d49daa..6546a993f2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,6 +14,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = pkg_resources,setuptools known_first_party = mmcv -known_third_party = addict,cv2,m2r,numpy,onnx,onnxruntime,pytest,recommonmark,resnet_cifar,torch,torchvision,yaml,yapf +known_third_party = addict,cv2,m2r,numpy,onnx,onnxoptimizer,onnxruntime,pytest,recommonmark,resnet_cifar,torch,torchvision,yaml,yapf no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY From 141293146be1abba18d88f62ed36c69a68333a73 Mon Sep 17 00:00:00 2001 From: maningsheng Date: Fri, 18 Dec 2020 18:08:50 +0800 Subject: [PATCH 02/10] add simple doc --- docs/onnx.md | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 docs/onnx.md diff --git a/docs/onnx.md b/docs/onnx.md new file mode 100644 index 0000000000..4a4c10df9c --- /dev/null +++ b/docs/onnx.md @@ -0,0 +1,36 @@ +# Introduction of `onnx` module in MMCV (Experimental) + +## register_extra_symbolics + +Some extra symbolic functions need to be registered before exporting Pytorch model to ONNX. + +### Example + +```python +import mmcv +from mmcv.onnx import register_extra_symbolics + +opset_version = 11 +register_extra_symbolics(opset_version) +``` + +## ONNX simplify + +### Intention + +`mmcv.onnx.simplify` is based on [onnx-simplifier](https://github.com/daquexian/onnx-simplifier), which is a useful tool to make exported ONNX models slimmer by performing a series of optimization. However, for Pytorch models with custom op from `mmcv`, it would break down. Thus, custom op for ONNX Runtime should be registered. 
+ +### Usage + +```python +import onnx +import numpy as np + +import mmcv +from mmcv.onnx import simplify + +input = np.random.randn(1, 3, 224, 224).astype(np.float32) +input_file = 'sample.onnx' +output_file = 'slim.onnx' +model = simplify(input_file, [input], output_file) +``` From 9853fe3ab6e1c60f39a9c39194e37498f6d6114f Mon Sep 17 00:00:00 2001 From: maningsheng Date: Fri, 18 Dec 2020 21:06:23 +0800 Subject: [PATCH 03/10] add unit test --- mmcv/onnx/simplify/core.py | 8 ++++---- tests/test_ops/test_onnx.py | 24 ++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/mmcv/onnx/simplify/core.py b/mmcv/onnx/simplify/core.py index d4c5c36d65..8ab58fe353 100644 --- a/mmcv/onnx/simplify/core.py +++ b/mmcv/onnx/simplify/core.py @@ -374,7 +374,7 @@ def simplify(model: Union[str, onnx.ModelProto], output_file: str = None, perform_optimization: bool = True, skip_fuse_bn: bool = False, - skip_shape_inference=True, + skip_shape_inference: bool = True, input_shapes: Dict[str, Sequence[int]] = None, skipped_optimizers: Sequence[str] = None) -> onnx.ModelProto: """Simplify and optimize an onnx model. @@ -383,10 +383,10 @@ def simplify(model: Union[str, onnx.ModelProto], input multiple input images for verification. Arguments: - model (str or onnx.ModelProto), path of model or loaded model object. - inputs (optional, Sequence[Dict[str, np.ndarray]]), inputs of model. + model (str or onnx.ModelProto): path of model or loaded model object. + inputs (optional, Sequence[Dict[str, np.ndarray]]): inputs of model. output_file (optional, str): output file to save simplified model. - perform_optimization (optional, bool), whether to perform optimization. + perform_optimization (optional, bool): whether to perform optimization. skip_fuse_bn (optional, bool): whether to skip fusing bn layer. skip_shape_inference (optional, bool): whether to skip shape inference. 
input_shapes (optional, Dict[str, Sequence[int]]): diff --git a/tests/test_ops/test_onnx.py b/tests/test_ops/test_onnx.py index 7f43a06c86..a991e03ad4 100644 --- a/tests/test_ops/test_onnx.py +++ b/tests/test_ops/test_onnx.py @@ -4,6 +4,7 @@ import numpy as np import onnx import onnxruntime as rt +import pytest import torch import torch.nn as nn @@ -182,3 +183,26 @@ def warpped_function(torch_input, torch_rois): # allclose os.remove(onnx_file) assert np.allclose(pytorch_output, onnx_output, atol=1e-3) + + +def test_simplify(): + try: + from mmcv.onnx import simplify + except ImportError: + pytest.skip('No simplify found in mmcv.onnx') + + def foo(x): + y = x.view((x.shape[0], x.shape[1], x.shape[3], x.shape[2])) + return y + + net = WrapFunction(foo) + dummy_input = torch.randn(2, 3, 4, 5) + torch.onnx.export(net, dummy_input, onnx_file, input_names=['input']) + ori_onnx_model = onnx.load(onnx_file) + + feed_input = [{'input': dummy_input.detach().cpu().numpy()}] + slim_onnx_model = simplify(ori_onnx_model, feed_input, onnx_file) + numel_before = len(ori_onnx_model.graph.node) + numel_after = len(slim_onnx_model.graph.node) + assert numel_before == 18 and numel_after == 1, 'Simplify failed.' 
+ os.remove(onnx_file) From f676a3675b2217fa2cc10c1af6612e4ff76d276b Mon Sep 17 00:00:00 2001 From: maningsheng Date: Tue, 22 Dec 2020 15:40:57 +0800 Subject: [PATCH 04/10] update docstring --- docs/onnx.md | 8 +- mmcv/onnx/simplify/core.py | 160 +++++++++++++++++++++++++++++-------- 2 files changed, 131 insertions(+), 37 deletions(-) diff --git a/docs/onnx.md b/docs/onnx.md index 4a4c10df9c..67dfe59074 100644 --- a/docs/onnx.md +++ b/docs/onnx.md @@ -28,9 +28,13 @@ import numpy as np import mmcv from mmcv.onnx import simplify - -input = np.random.randn(1, 3, 224, 224).astype(np.float32) +dummy_input = np.random.randn(1, 3, 224, 224).astype(np.float32) +input = {'input':dummy_input} input_file = 'sample.onnx' output_file = 'slim.onnx' model = simplify(input_file, [input], output_file) ``` + +### FAQs + +- None diff --git a/mmcv/onnx/simplify/core.py b/mmcv/onnx/simplify/core.py index 8ab58fe353..7be8707fd3 100644 --- a/mmcv/onnx/simplify/core.py +++ b/mmcv/onnx/simplify/core.py @@ -22,8 +22,10 @@ def add_features_to_output(m: onnx.ModelProto, nodes: List[onnx.NodeProto]) -> None: """Add features to output in pb, so that ONNX Runtime will output them. - :param m: the model that will be run in ONNX Runtime - :param nodes: nodes whose outputs will be added into the graph outputs + Args: + m (onnx.ModelProto): Input ONNX model. + nodes (List[onnx.NodeProto]): List of ONNX nodes, whose outputs + will be added into the graph output. """ for node in nodes: for output in node.output: @@ -52,9 +54,18 @@ def get_value_info_all(m: onnx.ModelProto, def get_shape(m: onnx.ModelProto, name: str) -> TensorShape: - """ - Note: This method relies on onnx shape inference, which is not reliable. - So only use it on input or output tensors + """Get shape info of a node in a model. + + Args: + m (onnx.ModelProto): Input model. + name (str): Name of a node. + + Returns: + TensorShape: Shape of a node. + + Note: + This method relies on onnx shape inference, which is not reliable. 
+ So only use it on input or output tensors """ v = get_value_info_all(m, name) if v is not None: @@ -70,6 +81,14 @@ def get_elem_type(m: onnx.ModelProto, name: str) -> Optional[int]: def get_np_type_from_elem_type(elem_type: int) -> int: + """Map element type from ONNX to dtype of numpy. + + Args: + elem_type (int): Element type index in ONNX. + + Returns: + int: Data type in numpy. + """ # from https://github.com/onnx/onnx/blob/ # e5e9a539f550f07ec156812484e8d4f33fb91f88/onnx/onnx.proto#L461 sizes = (None, np.float32, np.uint8, np.int8, np.uint16, np.int16, @@ -82,6 +101,14 @@ def get_np_type_from_elem_type(elem_type: int) -> int: def get_input_names(model: onnx.ModelProto) -> List[str]: + """Get input names of a model. + + Args: + model (onnx.ModelProto): Input ONNX model. + + Returns: + List[str]: List of input names. + """ input_names = list( set([ipt.name for ipt in model.graph.input]) - set([x.name for x in model.graph.initializer])) @@ -89,6 +116,14 @@ def get_input_names(model: onnx.ModelProto) -> List[str]: def add_initializers_into_inputs(model: onnx.ModelProto) -> onnx.ModelProto: + """add initializers into inputs of a model. + + Args: + model (onnx.ModelProto): Input ONNX model. + + Returns: + onnx.ModelProto: Updated ONNX model. + """ for x in model.graph.initializer: input_names = [x.name for x in model.graph.input] if x.name not in input_names: @@ -106,7 +141,18 @@ def add_initializers_into_inputs(model: onnx.ModelProto) -> onnx.ModelProto: return model -def generate_rand_input(model, input_shapes: Optional[TensorShapes] = None): +def generate_rand_input( + model: onnx.ModelProto, + input_shapes: Optional[TensorShapes] = None) -> Dict[str, np.ndarray]: + """Generate random input for a model. + + Args: + model (onnx.ModelProto): Input ONNX model. + input_shapes (TensorShapes, optional): Input shapes of the model. + + Returns: + Dict[str, np.ndarray]: Generated inputs of `np.ndarray`. 
+ """ if input_shapes is None: input_shapes = {} input_names = get_input_names(model) @@ -128,6 +174,14 @@ def generate_rand_input(model, input_shapes: Optional[TensorShapes] = None): def get_constant_nodes(m: onnx.ModelProto) -> List[onnx.NodeProto]: + """Collect constant nodes from a model. + + Args: + m (onnx.ModelProto): Input ONNX model. + + Returns: + List[onnx.NodeProto]: List of constant nodes. + """ const_nodes = [] const_tensors = [x.name for x in m.graph.initializer] @@ -169,9 +223,19 @@ def is_dynamic(node): def forward( - model, + model: onnx.ModelProto, inputs: Dict[str, np.ndarray] = None, input_shapes: Optional[TensorShapes] = None) -> Dict[str, np.ndarray]: + """Run forward on a model. + + Args: + model (onnx.ModelProto): Input ONNX model. + inputs (Dict[str, np.ndarray], optional): Inputs of the model. + input_shapes (TensorShapes, optional): Input shapes of the model. + + Returns: + Dict[str, np.ndarray]: Outputs of the model. + """ if input_shapes is None: input_shapes = {} sess_options = rt.SessionOptions() @@ -224,12 +288,18 @@ def insert_elem(repeated_container, index: int, element): def eliminate_const_nodes(model: onnx.ModelProto, const_nodes: List[onnx.NodeProto], res: Dict[str, np.ndarray]) -> onnx.ModelProto: + """Eliminate redundant constant nodes from model. + + Args: + model (onnx.ModelProto): The original ONNX model. + const_nodes (List[onnx.NodeProto]): + Constant nodes detected by `get_constant_nodes`. + res (Dict[str, np.ndarray]): Outputs of the model. + + Returns: + onnx.ModelProto: The simplified onnx model. """ - :param model: the original onnx model - :param const_nodes: const nodes detected by `get_constant_nodes` - :param res: The dict containing all tensors, got by `forward_all` - :return: the simplified onnx model. Redundant ops are all removed. 
- """ + for i, node in enumerate(model.graph.node): if node in const_nodes: for output in node.output: @@ -252,18 +322,21 @@ def eliminate_const_nodes(model: onnx.ModelProto, def optimize(model: onnx.ModelProto, skip_fuse_bn: bool, skipped_optimizers: Optional[Sequence[str]]) -> onnx.ModelProto: - """ - :param model: The onnx model. - :return: The optimized onnx model. - Before simplifying, use this method to generate value_info, - which is used in `forward_all` - After simplifying, use this method to fold constants generated - in previous step into initializer, and eliminate unused constants. - """ + """Perform optimization on an ONNX model. Before simplifying, use this + method to generate value_info. After simplifying, use this method to fold + constants generated in previous step into initializer, and eliminate unused + constants. + + Args: + model (onnx.ModelProto): The input ONNX model. + skip_fuse_bn (bool): Whether to skip fuse bn. + skipped_optimizers (Sequence[str]): List of optimizers to be skipped. + Returns: + onnx.ModelProto: The optimized model. + """ # Due to a onnx bug, https://github.com/onnx/onnx/issues/2417, # we need to add missing initializers into inputs - onnx.checker.check_model(model) input_num = len(model.graph.input) model = add_initializers_into_inputs(model) @@ -301,15 +374,20 @@ def check(model_opt: onnx.ModelProto, n_times: int = 5, input_shapes: Optional[TensorShapes] = None, inputs: Optional[List[Dict[str, np.ndarray]]] = None) -> bool: + """Check model before and after simplify. + + Args: + model_opt (onnx.ModelProto): Optimized model. + model_ori (onnx.ModelProto): Original model. + n_times (int, optional): Number of times to compare models. + input_shapes (TensorShapes, optional): Input shapes of the model. + inputs (List[Dict[str, np.ndarray]], optional): Inputs of the model. + + Returns: + bool: `True` means the outputs of two models have neglectable + numeric difference. 
""" - Warning: - Some models (e.g., MobileNet) may fail this check by a small magnitude. - Just ignore if it happens. - :param input_shapes: Shapes of generated random inputs - :param model_opt: The simplified ONNX model - :param model_ori: The original ONNX model - :param n_times: Generate n random inputs - """ + if input_shapes is None: input_shapes = {} onnx.checker.check_model(model_opt) @@ -344,13 +422,20 @@ def check(model_opt: onnx.ModelProto, def clean_constant_nodes(const_nodes: List[onnx.NodeProto], res: Dict[str, np.ndarray]): - """It seems not needed since commit 6f2a72, but maybe it still prevents - some unknown bug. + """Clean constant nodes. + + Args: + const_nodes (List[onnx.NodeProto]): List of constant nodes. + res (Dict[str, np.ndarray]): The forward result of model. + + Returns: + List[onnx.NodeProto]: The constant nodes which have an output in res. - :param const_nodes: const nodes detected by `get_constant_nodes` - :param res: The dict containing all tensors, got by `forward_all` - :return: The constant nodes which have an output in res + Notes: + It seems not needed since commit 6f2a72, but maybe it still prevents + some unknown bug. 
""" + return [node for node in const_nodes if node.output[0] in res] @@ -403,7 +488,8 @@ def simplify(model: Union[str, onnx.ModelProto], >>> >>> from mmcv.onnx import simplify >>> - >>> input = np.random.randn(1, 3, 224, 224).astype(np.float32) + >>> dummy_input = np.random.randn(1, 3, 224, 224).astype(np.float32) + >>> input = {'input':dummy_input} >>> input_file = 'sample.onnx' >>> output_file = 'slim.onnx' >>> model = simplify(input_file, [input], output_file) @@ -417,6 +503,7 @@ def simplify(model: Union[str, onnx.ModelProto], model = add_suffix2name(model) onnx.checker.check_model(model) model_ori = copy.deepcopy(model) + numel_node_ori = len(model_ori.graph.node) if not skip_shape_inference: model = onnx.shape_inference.infer_shapes(model) @@ -440,6 +527,9 @@ def simplify(model: Union[str, onnx.ModelProto], model_ori, model, input_shapes=input_shapes, inputs=inputs) assert check_ok, 'Check failed for the simplified model!' + numel_node_slim = len(model.graph.node) + print(f'Number of nodes: {numel_node_ori} -> {numel_node_slim}') + if output_file is not None: save_dir, _ = os.path.split(output_file) if save_dir: From 1d77760a021889364c89aede02590a36be12681d Mon Sep 17 00:00:00 2001 From: maningsheng Date: Wed, 30 Dec 2020 11:52:25 +0800 Subject: [PATCH 05/10] resolve some comment --- docs/onnx.md | 4 ++-- mmcv/onnx/simplify/core.py | 3 +-- tests/test_ops/test_onnx.py | 8 ++------ 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/docs/onnx.md b/docs/onnx.md index 67dfe59074..3a1d24f311 100644 --- a/docs/onnx.md +++ b/docs/onnx.md @@ -2,7 +2,7 @@ ## register_extra_symbolics -Some extra symbolic functions need to be registered before exporting Pytorch model to ONNX. +Some extra symbolic functions need to be registered before exporting PyTorch model to ONNX. 
### Example @@ -18,7 +18,7 @@ register_extra_symbolics(opset_version) ### Intention -`mmcv.onnx.simplify` is based on [onnx-simplifier](https://github.com/daquexian/onnx-simplifier), which is a useful tool to make exported ONNX models slimmer by performing a series of optimization. However, for Pytorch models with custom op from `mmcv`, it would break down. Thus, custom op for ONNX Runtime should be registered. +`mmcv.onnx.simplify` is based on [onnx-simplifier](https://github.com/daquexian/onnx-simplifier), which is a useful tool to make exported ONNX models slimmer by performing a series of optimizations. However, for PyTorch models with custom ops from `mmcv`, it would break down. Thus, custom ops for ONNX Runtime should be registered. ### Usage diff --git a/mmcv/onnx/simplify/core.py b/mmcv/onnx/simplify/core.py index 7be8707fd3..dff32bb5f8 100644 --- a/mmcv/onnx/simplify/core.py +++ b/mmcv/onnx/simplify/core.py @@ -89,8 +89,7 @@ def get_np_type_from_elem_type(elem_type: int) -> int: Returns: int: Data type in numpy. 
""" - # from https://github.com/onnx/onnx/blob/ - # e5e9a539f550f07ec156812484e8d4f33fb91f88/onnx/onnx.proto#L461 + # from https://github.com/onnx/onnx/blob/e5e9a539f550f07ec156812484e8d4f33fb91f88/onnx/onnx.proto#L461 # noqa: E501 sizes = (None, np.float32, np.uint8, np.int8, np.uint16, np.int16, np.int32, np.int64, str, np.bool, np.float16, np.double, np.uint32, np.uint64, np.complex64, np.complex128, np.float16) diff --git a/tests/test_ops/test_onnx.py b/tests/test_ops/test_onnx.py index a991e03ad4..5cd8e691e4 100644 --- a/tests/test_ops/test_onnx.py +++ b/tests/test_ops/test_onnx.py @@ -4,7 +4,6 @@ import numpy as np import onnx import onnxruntime as rt -import pytest import torch import torch.nn as nn @@ -186,10 +185,7 @@ def warpped_function(torch_input, torch_rois): def test_simplify(): - try: - from mmcv.onnx import simplify - except ImportError: - pytest.skip('No simplify found in mmcv.onnx') + from mmcv.onnx import simplify def foo(x): y = x.view((x.shape[0], x.shape[1], x.shape[3], x.shape[2])) @@ -204,5 +200,5 @@ def foo(x): slim_onnx_model = simplify(ori_onnx_model, feed_input, onnx_file) numel_before = len(ori_onnx_model.graph.node) numel_after = len(slim_onnx_model.graph.node) - assert numel_before == 18 and numel_after == 1, 'Simplify failed.' os.remove(onnx_file) + assert numel_before == 18 and numel_after == 1, 'Simplify failed.' 
From 9fd0897341344681d339afd8fa2d1d37b9f893bd Mon Sep 17 00:00:00 2001 From: maningsheng Date: Wed, 30 Dec 2020 13:09:41 +0800 Subject: [PATCH 06/10] add test dependency:onnxoptimizer --- requirements/test.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements/test.txt b/requirements/test.txt index e26c846137..d64268f2f8 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,6 +1,7 @@ coverage lmdb onnx==1.7.0 -onnxruntime==1.4.0 +onnxoptimizer +onnxruntime==1.6.0 pytest PyTurboJPEG From 0cabe295719f6c0a55d5eb0d6ef91727ccd68e79 Mon Sep 17 00:00:00 2001 From: maningsheng Date: Wed, 30 Dec 2020 14:49:37 +0800 Subject: [PATCH 07/10] Fix onnxruntime register empty libpath --- mmcv/onnx/simplify/core.py | 4 ++-- requirements/test.txt | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mmcv/onnx/simplify/core.py b/mmcv/onnx/simplify/core.py index dff32bb5f8..20428c81ca 100644 --- a/mmcv/onnx/simplify/core.py +++ b/mmcv/onnx/simplify/core.py @@ -239,13 +239,13 @@ def forward( input_shapes = {} sess_options = rt.SessionOptions() # load custom lib for onnxruntime in mmcv - ort_custom_op_path = None + ort_custom_op_path = '' try: from mmcv.ops import get_onnxruntime_op_path ort_custom_op_path = get_onnxruntime_op_path() except ImportError: pass - if ort_custom_op_path is not None: + if os.path.exists(ort_custom_op_path): sess_options.register_custom_ops_library(ort_custom_op_path) sess_options.graph_optimization_level = rt.GraphOptimizationLevel(0) sess_options.log_severity_level = 3 diff --git a/requirements/test.txt b/requirements/test.txt index d64268f2f8..2c8d01044d 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -2,6 +2,6 @@ coverage lmdb onnx==1.7.0 onnxoptimizer -onnxruntime==1.6.0 +onnxruntime==1.5.1 pytest PyTurboJPEG From d48d2fa30725cd95db580df27ceb0228ca461641 Mon Sep 17 00:00:00 2001 From: maningsheng Date: Wed, 30 Dec 2020 15:56:13 +0800 Subject: [PATCH 08/10] test onnxruntime 
version --- requirements/test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/test.txt b/requirements/test.txt index 2c8d01044d..bec4fabd1c 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -2,6 +2,6 @@ coverage lmdb onnx==1.7.0 onnxoptimizer -onnxruntime==1.5.1 +onnxruntime==1.4.0 pytest PyTurboJPEG From 45519829a2ea9ba6465b3d78e7c8529ca18cea34 Mon Sep 17 00:00:00 2001 From: maningsheng Date: Wed, 30 Dec 2020 16:41:47 +0800 Subject: [PATCH 09/10] set checker to false --- mmcv/onnx/simplify/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmcv/onnx/simplify/common.py b/mmcv/onnx/simplify/common.py index 701e59aab1..6490bd643e 100644 --- a/mmcv/onnx/simplify/common.py +++ b/mmcv/onnx/simplify/common.py @@ -4,7 +4,7 @@ import onnx -def add_suffix2name(ori_model, suffix='__', verify=True): +def add_suffix2name(ori_model, suffix='__', verify=False): """Simplily add a suffix to the name of node, which has a numeric name.""" # check if has special op, which has subgraph. 
special_ops = ('If', 'Loop') From 5b1aca84dd5a004facb15ccdfda011f0dfe7f94e Mon Sep 17 00:00:00 2001 From: maningsheng Date: Thu, 31 Dec 2020 17:49:26 +0800 Subject: [PATCH 10/10] skip test_simplify for torch<1.5.0 --- setup.cfg | 2 +- tests/test_ops/test_onnx.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 6546a993f2..1ef231612f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,6 +14,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = pkg_resources,setuptools known_first_party = mmcv -known_third_party = addict,cv2,m2r,numpy,onnx,onnxoptimizer,onnxruntime,pytest,recommonmark,resnet_cifar,torch,torchvision,yaml,yapf +known_third_party = addict,cv2,m2r,numpy,onnx,onnxoptimizer,onnxruntime,packaging,pytest,recommonmark,resnet_cifar,torch,torchvision,yaml,yapf no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/tests/test_ops/test_onnx.py b/tests/test_ops/test_onnx.py index 21da30c462..d70e2fb677 100644 --- a/tests/test_ops/test_onnx.py +++ b/tests/test_ops/test_onnx.py @@ -8,6 +8,7 @@ import pytest import torch import torch.nn as nn +from packaging import version onnx_file = 'tmp.onnx' @@ -65,7 +66,6 @@ def test_nms(): reason='CUDA is unavailable for test_softnms') def test_softnms(): from mmcv.ops import get_onnxruntime_op_path, soft_nms - from packaging import version # only support pytorch >= 1.7.0 if version.parse(torch.__version__) < version.parse('1.7.0'): @@ -275,6 +275,10 @@ def warpped_function(torch_input, torch_rois): def test_simplify(): from mmcv.onnx import simplify + # only support PyTorch >= 1.5.0 + if version.parse(torch.__version__) < version.parse('1.5.0'): + pytest.skip('mmcv.onnx.simplify only support with PyTorch >= 1.5.0') + def foo(x): y = x.view((x.shape[0], x.shape[1], x.shape[3], x.shape[2])) return y