Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] add onnx simplify #751

Merged
merged 11 commits into from
Dec 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions docs/onnx.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Introduction of `onnx` module in MMCV (Experimental)

## register_extra_symbolics

Some extra symbolic functions need to be registered before exporting PyTorch model to ONNX.

### Example

```python
import mmcv
from mmcv.onnx import register_extra_symbolics

opset_version = 11
register_extra_symbolics(opset_version)
```

## ONNX simplify

### Intention

`mmcv.onnx.simplify` is based on [onnx-simplifier](https://github.com/daquexian/onnx-simplifier), which is a useful tool to make exported ONNX models slimmer by performing a series of optimization. However, for Pytorch models with custom op from `mmcv`, it would break down. Thus, custom ops for ONNX Runtime should be registered.

### Usage

```python
import onnx
import numpy as np

import mmcv
from mmcv.onnx import simplify
dummy_input = np.random.randn(1, 3, 224, 224).astype(np.float32)
input = {'input':dummy_input}
input_file = 'sample.onnx'
output_file = 'slim.onnx'
model = simplify(input_file, [input], output_file)
```

### FAQs

- None
3 changes: 2 additions & 1 deletion mmcv/onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .simplify import simplify
from .symbolic import register_extra_symbolics

__all__ = ['register_extra_symbolics']
__all__ = ['register_extra_symbolics', 'simplify']
3 changes: 3 additions & 0 deletions mmcv/onnx/simplify/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .core import simplify

__all__ = ['simplify']
43 changes: 43 additions & 0 deletions mmcv/onnx/simplify/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import copy
import warnings

import onnx


def add_suffix2name(ori_model, suffix='__', verify=False):
"""Simplily add a suffix to the name of node, which has a numeric name."""
# check if has special op, which has subgraph.
special_ops = ('If', 'Loop')
for node in ori_model.graph.node:
if node.op_type in special_ops:
warnings.warn(f'This model has special op: {node.op_type}.')
return ori_model

model = copy.deepcopy(ori_model)

def need_update(name):
return name.isnumeric()

def update_name(nodes):
for node in nodes:
if need_update(node.name):
node.name += suffix

update_name(model.graph.initializer)
update_name(model.graph.input)
update_name(model.graph.output)

for i, node in enumerate(ori_model.graph.node):
# process input of node
for j, name in enumerate(node.input):
if need_update(name):
model.graph.node[i].input[j] = name + suffix

# process output of node
for j, name in enumerate(node.output):
if need_update(name):
model.graph.node[i].output[j] = name + suffix
if verify:
onnx.checker.check_model(model)

return model
Loading