Skip to content

Commit

Permalink
add model splitting support (open-mmlab#1)
Browse files Browse the repository at this point in the history
* add function marker and model extractor

* add fsaf split & partial mask rcnn split, import extract.py

* 1. add value renaming  2. add apply_marks in config to turn on/off marks

* rewind changes on pytorch2onnx.py

Co-authored-by: q.yao <streetyao@live.com>
  • Loading branch information
lzhangzz and grimoire authored Jun 28, 2021
1 parent 5998d24 commit ef41f69
Show file tree
Hide file tree
Showing 10 changed files with 340 additions and 4 deletions.
4 changes: 4 additions & 0 deletions configs/mmdet/split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Deployment config that turns on function marks so an exported model can
# later be split at the marked boundaries.
# NOTE(review): `_base_` presumably follows the mmcv config-inheritance
# convention (listed files are loaded first) — confirm.
_base_ = ['./base.py', '../_base_/backends/tensorrt.py']

backend = 'default'
# Read by the `Mark.symbolic` rewriter through
# `rewriter.cfg.get("apply_marks", False)`; when falsy no Mark op is emitted.
apply_marks = True
2 changes: 1 addition & 1 deletion mmdeploy/apis/pytorch2onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,4 @@ def torch2onnx(img: Any,
keep_initializers_as_inputs=pytorch2onnx_cfg[
'keep_initializers_as_inputs'])

ret_value.value = 0
ret_value.value = 0
4 changes: 3 additions & 1 deletion mmdeploy/mmdet/models/dense_heads/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .anchor_head import AnchorHead
from .rpn_head import rpn_head_forward
from .fsaf_head import fsaf_head_forward

__all__ = ['AnchorHead']
__all__ = ['AnchorHead', 'rpn_head_forward', 'fsaf_head_forward']
7 changes: 7 additions & 0 deletions mmdeploy/mmdet/models/dense_heads/fsaf_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from mmdeploy.utils import FUNCTION_REWRITERS, mark


@FUNCTION_REWRITERS.register_rewriter('mmdet.models.FSAFHead.forward')
@mark('rpn_forward')
def fsaf_head_forward(rewriter, *args):
    """Rewrite of ``FSAFHead.forward`` that only adds function marks.

    The call is forwarded unchanged to the original implementation; the
    ``mark`` decorator is what tags the tensors during ONNX export.

    NOTE(review): the marker label is 'rpn_forward', the same label used by
    the RPNHead rewrite — presumably so one split config serves both heads;
    confirm this is intentional.

    Args:
        rewriter: Rewriter context exposing ``origin_func``.
        *args: Positional arguments forwarded verbatim.

    Returns:
        Whatever the original function returns.
    """
    original_forward = rewriter.origin_func
    return original_forward(*args)
7 changes: 7 additions & 0 deletions mmdeploy/mmdet/models/dense_heads/rpn_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from mmdeploy.utils import FUNCTION_REWRITERS, mark


@FUNCTION_REWRITERS.register_rewriter('mmdet.models.RPNHead.forward')
@mark('rpn_forward')
def rpn_head_forward(rewriter, self, feats):
    """Rewrite of ``RPNHead.forward`` that only adds function marks.

    Args:
        rewriter: Rewriter context exposing ``origin_func``.
        self: The head instance the original method is bound to.
        feats: Features handed to the original ``forward``.

    Returns:
        The result of the original ``forward``, unchanged by this function
        itself (marking is done by the ``mark`` decorator).
    """
    original_forward = rewriter.origin_func
    return original_forward(self, feats)
3 changes: 2 additions & 1 deletion mmdeploy/mmdet/models/detectors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .single_stage import SingleStageDetector
from .two_stage import extract_feat

__all__ = ['SingleStageDetector']
__all__ = ['SingleStageDetector', 'extract_feat']
19 changes: 19 additions & 0 deletions mmdeploy/mmdet/models/detectors/two_stage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from mmdeploy.utils import FUNCTION_REWRITERS, mark
from mmdeploy.utils import SYMBOLICS_REGISTER
from mmcv.onnx.symbolic import grid_sampler


@FUNCTION_REWRITERS.register_rewriter(
    'mmdet.models.TwoStageDetector.extract_feat')
@mark('extract_feat')
def extract_feat(rewriter, self, img):
    """Rewrite of ``TwoStageDetector.extract_feat`` that adds marks.

    Args:
        rewriter: Rewriter context exposing ``origin_func``.
        self: The detector instance.
        img: Image input forwarded to the original method.

    Returns:
        The result of the original ``extract_feat``.
    """
    original_extract = rewriter.origin_func
    return original_extract(self, img)


@FUNCTION_REWRITERS.register_rewriter('mmdet.models.TwoStageDetector.forward')
def two_stage_forward(rewriter, self, img, *args):
    """Rewrite of ``TwoStageDetector.forward`` for inference-style calls.

    Calls the original ``forward`` with the image wrapped in a list, a
    placeholder ``img_metas`` of ``[[{}]]`` and ``return_loss=False``.

    Args:
        rewriter: Rewriter context exposing ``origin_func``.
        self: The detector instance.
        img: Image input; passed on as ``[img]``.
        *args: Extra positional arguments. They are bound right after
            ``[img]``, i.e. before the keyword arguments.

    Returns:
        The result of the original ``forward``.
    """
    placeholder_metas = [[{}]]
    return rewriter.origin_func(
        self, [img], *args, img_metas=placeholder_metas, return_loss=False)


@SYMBOLICS_REGISTER.register_symbolic('grid_sampler', is_pytorch=True)
def symbolic_grid_sample(symbolic_wrapper, *args):
    """Symbolic for 'grid_sampler' that delegates to mmcv's ``grid_sampler``.

    Args:
        symbolic_wrapper: Supplied by the register; not used here.
        *args: Arguments forwarded verbatim to ``grid_sampler``.

    Returns:
        Whatever ``grid_sampler`` returns.
    """
    forwarded = args
    return grid_sampler(*forwarded)
3 changes: 2 additions & 1 deletion mmdeploy/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .function_rewriter import FUNCTION_REWRITERS, RewriterContext
from .module_rewriter import MODULE_REWRITERS, patch_model
from .symbolic_register import SYMBOLICS_REGISTER, register_extra_symbolics
from .function_marker import mark

__all__ = [
'RewriterContext', 'FUNCTION_REWRITERS', 'MODULE_REWRITERS', 'patch_model',
'SYMBOLICS_REGISTER', 'register_extra_symbolics'
'SYMBOLICS_REGISTER', 'register_extra_symbolics', 'mark'
]
61 changes: 61 additions & 0 deletions mmdeploy/utils/function_marker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import inspect
import torch
from .function_rewriter import FUNCTION_REWRITERS


class Mark(torch.autograd.Function):
    """Pass-through autograd function that shows up as a marker op.

    ``forward`` hands its first argument back untouched, while ``symbolic``
    emits an ``mmcv::Mark`` op carrying the marker metadata.
    """

    @staticmethod
    def forward(ctx, x, *args):
        """Return ``x`` as-is; the remaining arguments are ignored here."""
        return x

    @staticmethod
    def symbolic(g, x, type, name, id, attrs):
        """Emit the ``mmcv::Mark`` op for ``x``.

        Args:
            g: Graph-like object providing ``op``.
            x: Value being marked.
            type: Stored as the string attribute ``type``.
            name: Stored as the string attribute ``name``.
            id: Stored as the integer attribute ``id``.
            attrs: Extra attributes expanded as keyword arguments to ``g.op``.

        Returns:
            The value returned by ``g.op``.
        """
        return g.op(
            "mmcv::Mark", x, type_s=type, name_s=name, id_i=id, **attrs)


@FUNCTION_REWRITERS.register_rewriter(
    "mmdeploy.utils.function_marker.Mark.symbolic")
def mark_symbolic(rewriter, g, x, *args):
    """Gate ``Mark.symbolic`` on the ``apply_marks`` config flag.

    Args:
        rewriter: Rewriter context exposing ``cfg`` and ``origin_func``.
        g: Graph object passed to the original symbolic.
        x: Value being marked.
        *args: Remaining symbolic arguments, forwarded verbatim.

    Returns:
        The original symbolic's result when ``apply_marks`` is truthy in the
        config; otherwise ``x`` itself, so no marker op is emitted.
    """
    marks_enabled = rewriter.cfg.get("apply_marks", False)
    if not marks_enabled:
        return x
    return rewriter.origin_func(g, x, *args)


def mark_tensors(xs, type, name, attrs):
    """Wrap every tensor found in a nested structure with ``Mark``.

    Lists, tuples and dicts are rebuilt recursively; anything else that is
    not a tensor is returned unchanged. Each tensor is marked at most once
    per call and receives a running id starting at 0.

    Args:
        xs: A tensor or a nesting of list / tuple / dict containing tensors.
        type: Marker type forwarded to ``Mark.apply``.
        name: Root of the marker name; children extend it with ``/<key>``.
        attrs: Extra attributes forwarded to ``Mark.apply``.

    Returns:
        A structure shaped like ``xs`` with tensors replaced by marked ones.
    """
    seen = set()
    next_id = 0

    def walk(node, path):
        nonlocal next_id
        if isinstance(node, torch.Tensor):
            if node in seen:
                # Already marked during this call; hand it back untouched.
                return node
            seen.add(node)
            mark_id = next_id
            next_id += 1
            return Mark.apply(node, type, path, mark_id, attrs)
        if isinstance(node, list):
            return [
                walk(item, f'{path}/{pos}') for pos, item in enumerate(node)
            ]
        if isinstance(node, tuple):
            return tuple(
                walk(item, f'{path}/{pos}') for pos, item in enumerate(node))
        if isinstance(node, dict):
            return {
                key: walk(value, f'{path}/{key}')
                for key, value in node.items()
            }
        return node

    return walk(xs, name)


def mark(func, **attrs):
    """Build a decorator that marks a function's tensors during ONNX export.

    While ``torch.onnx.is_in_onnx_export()`` is true, every positional
    argument is passed through ``mark_tensors`` as an 'input' and the return
    value as an 'output'. Outside of export the function is called directly.
    Keyword arguments are forwarded unmarked.

    Args:
        func: Marker label; stored in ``attrs`` under ``'func_s'`` and used
            as the name root for outputs.
        **attrs: Extra attributes attached to every mark.

    Returns:
        A decorator producing the wrapped function.
    """
    # Local import keeps this block self-contained.
    import functools

    attrs['func_s'] = func

    def decorator(f):
        parameters = inspect.signature(f).parameters
        names = list(parameters)
        var_positional = next(
            (p.name for p in parameters.values()
             if p.kind is inspect.Parameter.VAR_POSITIONAL), 'args')

        def input_name(position):
            # Arguments beyond the declared parameters (absorbed by *args)
            # still need a label; pairing with zip() would silently drop
            # them and call `f` with too few arguments.
            if position < len(names):
                return names[position]
            return f'{var_positional}/{position}'

        @functools.wraps(f)
        def g(*args, **kwargs):
            if not torch.onnx.is_in_onnx_export():
                return f(*args, **kwargs)
            args = [
                mark_tensors(arg, 'input', input_name(position), attrs)
                for position, arg in enumerate(args)
            ]
            rets = f(*args, **kwargs)
            # TODO: maybe we can traverse the AST to get the retval names?
            return mark_tensors(rets, 'output', func, attrs)

        return g

    return decorator
234 changes: 234 additions & 0 deletions tools/extract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
import argparse
import os.path as osp
import onnx
import onnx.utils
import onnx.helper
from onnx import AttributeProto


def parse_args():
    """Parse the command line of the extraction tool.

    Returns:
        The parsed namespace, with ``start`` and ``end`` converted from
        comma-separated strings into lists (empty when not given).
    """
    cli = argparse.ArgumentParser(
        description='Extract model based on markers.')
    cli.add_argument('input_model', help='Input ONNX model')
    cli.add_argument('output_model', help='Output ONNX model')
    cli.add_argument(
        '--start',
        help='Start markers, format: func:type, e.g. backbone:input')
    cli.add_argument('--end', help='End markers')

    parsed = cli.parse_args()

    # Turn "a:input,b:output" into ['a:input', 'b:output'].
    for field in ('start', 'end'):
        raw = getattr(parsed, field)
        setattr(parsed, field, raw.split(',') if raw else [])

    return parsed


def remove_markers(model):
    """Delete every 'Mark' node and reconnect its consumers to its input.

    Each output of a removed Mark node is treated as an alias of the node's
    first input. Aliases are resolved transitively, so a chain of Mark nodes
    collapses onto the value feeding the first one. Inputs of the remaining
    nodes and the names of graph outputs are rewritten accordingly.

    Args:
        model: Model whose ``graph.node`` and ``graph.output`` are edited
            in place.

    Returns:
        The same ``model`` object.
    """
    graph = model.graph

    # Map "value produced by a Mark" -> "value that was fed into it".
    alias = {}
    # Walking backwards keeps the remaining indices valid while deleting.
    for index in reversed(range(len(graph.node))):
        node = graph.node[index]
        if node.op_type != 'Mark':
            continue
        if node.input:
            source = node.input[0]
            for produced in node.output:
                alias[produced] = source
        del graph.node[index]

    def resolve(name):
        # Follow alias links to the first non-marker value; `hops` guards
        # against a malformed cyclic chain.
        hops = set()
        while name in alias and name not in hops:
            hops.add(name)
            name = alias[name]
        return name

    for node in graph.node:
        for k, name in enumerate(node.input):
            if name in alias:
                node.input[k] = resolve(name)

    # TODO: handle duplicated case?
    for output in graph.output:
        if output.name in alias:
            output.name = resolve(output.name)
    return model


def attribute_to_dict(attribute):
    """Convert node attributes into a plain ``{name: value}`` dict.

    Only STRING attributes (decoded as UTF-8) and INT attributes are kept;
    attributes of any other type are skipped.

    Args:
        attribute: Iterable of attribute messages exposing ``name``,
            ``type`` and ``s`` / ``i``.

    Returns:
        Dict mapping attribute names to ``str`` or ``int`` values.
    """
    decoded = {}
    for attr in attribute:
        kind = attr.type
        if kind == AttributeProto.AttributeType.STRING:
            value = str(attr.s, 'utf-8')
        elif kind == AttributeProto.AttributeType.INT:
            value = attr.i
        else:
            continue
        decoded[attr.name] = value
    return decoded


def _dfs_search_reacable_nodes_fast(self, node_output_name, graph_input_nodes, reachable_nodes):
outputs = {}
for index, node in enumerate(self.graph.node):
for name in node.output:
if name not in outputs:
outputs[name] = set()
outputs[name].add(index)

def impl(node_output_name, graph_input_nodes, reachable_nodes):
if node_output_name in graph_input_nodes:
return
if node_output_name not in outputs:
return
for index in outputs[node_output_name]:
node = self.graph.node[index]
if node in reachable_nodes:
continue
reachable_nodes.append(node)
for name in node.input:
impl(name, graph_input_nodes, reachable_nodes)
impl(node_output_name, graph_input_nodes, reachable_nodes)


def get_new_name(attrs):
    """Choose the value name for a marker from its attributes.

    Args:
        attrs: Attribute dict of a marker node.

    Returns:
        ``attrs['name']`` when present, otherwise ``<func>_<type>_<id>``.
    """
    if 'name' not in attrs:
        parts = (attrs['func'], attrs['type'], str(attrs['id']))
        return '_'.join(parts)
    return attrs['name']


def rename_value(model, old_name, new_name):
    """Rename a value everywhere it is referenced in the graph.

    Node inputs and outputs hold plain names and are replaced by index. The
    entries of ``graph.value_info``, ``graph.input`` and ``graph.output``
    are objects carrying a ``name`` field, so they are matched and updated
    through that field. Comparing those entries to a string directly never
    matches, which left graph inputs and outputs with the stale name.

    Args:
        model: Model whose graph is edited in place.
        old_name: Current name of the value.
        new_name: Replacement name.
    """
    graph = model.graph
    for node in graph.node:
        for i, name in enumerate(node.output):
            if name == old_name:
                node.output[i] = new_name
        for i, name in enumerate(node.input):
            if name == old_name:
                node.input[i] = new_name
    for collection in (graph.value_info, graph.input, graph.output):
        for value in collection:
            if value.name == old_name:
                value.name = new_name


def _collect_marked_values(model, markers, output_side_type):
    """Rename and collect the values selected by ``func:type`` markers.

    For every Mark node matching a marker, the selected value is the node's
    first output when the marker type equals ``output_side_type`` and its
    first input otherwise. The value is renamed via ``get_new_name`` and the
    new name is collected once.
    """
    collected = []
    for marker in markers:
        func_name, mark_type = marker.split(':')
        assert mark_type in ['input', 'output'], \
            f'marker type must be "input" or "output", got: {marker}'
        for node in model.graph.node:
            if node.op_type != 'Mark':
                continue
            attr = attribute_to_dict(node.attribute)
            if attr['func'] != func_name or attr['type'] != mark_type:
                continue
            if mark_type == output_side_type:
                name = node.output[0]
            else:
                name = node.input[0]
            # A value renamed by an earlier match already carries its new
            # name here, so this check also de-duplicates shared values.
            if name not in collected:
                new_name = get_new_name(attr)
                rename_value(model, name, new_name)
                collected.append(new_name)
    return collected


def extract_model(model, start, end):
    """Extract the sub-model delimited by start and end markers.

    Args:
        model: ONNX model containing Mark nodes; it is modified in place
            (values renamed, Mark nodes turned into Identity).
        start: Marker or list of markers, each ``func:type`` with type
            'input' or 'output', selecting the sub-model inputs.
        end: Marker or list of markers selecting the sub-model outputs.

    Returns:
        The extracted model produced by ``onnx.utils.Extractor``, after
        pruning unused inputs and patching shape information.
    """
    if not isinstance(start, (list, tuple)):
        start = [start]
    # For a start marker, the value *after* an 'input' mark is the boundary.
    inputs = _collect_marked_values(model, start, 'input')
    print(f'inputs: {inputs}')

    if not isinstance(end, (list, tuple)):
        end = [end]
    # For an end marker, the value *after* an 'output' mark is the boundary.
    outputs = _collect_marked_values(model, end, 'output')
    print(f'outputs: {outputs}')

    # replace Mark with Identity
    for node in model.graph.node:
        if node.op_type == 'Mark':
            del node.attribute[:]
            node.domain = ''
            node.op_type = 'Identity'

    # patch extractor
    onnx.utils.Extractor._dfs_search_reachable_nodes = \
        _dfs_search_reacable_nodes_fast

    extractor = onnx.utils.Extractor(model)
    extracted_model = extractor.extract_model(inputs, outputs)
    graph = extracted_model.graph

    # collect every name that is consumed by a node or exposed as an output
    used = set()
    for node in graph.node:
        used.update(node.input)
    for graph_output in graph.output:
        used.add(graph_output.name)

    # delete unused inputs; walking backwards keeps indices valid
    for index in reversed(range(len(graph.input))):
        if graph.input[index].name not in used:
            del graph.input[index]

    # eliminate output without shape
    for graph_output in graph.output:
        if not graph_output.type.tensor_type.shape.dim:
            print(f'fixing output shape: {graph_output.name}')
            graph_output.CopyFrom(
                onnx.helper.make_tensor_value_info(
                    graph_output.name,
                    graph_output.type.tensor_type.elem_type, []))

    # eliminate 0-batch dimension, dirty workaround for two-stage detectors
    # NOTE(review): a first dim without an explicit value also reads as 0
    # here — confirm that forcing it to 1 is intended in that case.
    for graph_input in graph.input:
        if graph_input.name in inputs:
            dims = graph_input.type.tensor_type.shape.dim
            # Guard against inputs that carry no dimensions at all.
            if len(dims) > 0 and dims[0].dim_value == 0:
                dims[0].dim_value = 1

    # eliminate duplicated value_info for inputs
    for index in reversed(range(len(graph.value_info))):
        if graph.value_info[index].name in inputs:
            del graph.value_info[index]

    return extracted_model


def collect_avaiable_marks(model):
    """List the distinct marker function names found in the model.

    Args:
        model: Model whose ``graph.node`` entries are scanned for nodes
            with ``op_type == 'Mark'``.

    Returns:
        The ``func`` attribute of each Mark node, without duplicates, in
        order of first appearance.
    """
    found = []
    mark_nodes = (node for node in model.graph.node
                  if node.op_type == 'Mark')
    for node in mark_nodes:
        func_name = attribute_to_dict(node.attribute)['func']
        if func_name in found:
            continue
        found.append(func_name)
    return found


def main():
    """Command-line entry point: load, report marks, extract and save.

    The output path gets a '.onnx' suffix appended when it does not already
    end with one.
    """
    cli_args = parse_args()

    source_model = onnx.load(cli_args.input_model)
    available = collect_avaiable_marks(source_model)
    print("Available marks:\n {}".format('\n '.join(available)))

    result = extract_model(source_model, cli_args.start, cli_args.end)

    target_path = cli_args.output_model
    if osp.splitext(target_path)[-1] != '.onnx':
        target_path += '.onnx'
    onnx.save(result, target_path)


if __name__ == '__main__':
    main()

0 comments on commit ef41f69

Please sign in to comment.