forked from open-mmlab/mmdetection3d
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add model splitting support (open-mmlab#1)
* add function marker and model extractor * add fsaf split & partial mask rcnn split, import extract.py * 1. add value renaming 2. add apply_marks in config to turn on/off marks * rewind changes on pytorch2onnx.py Co-authored-by: q.yao <streetyao@live.com>
- Loading branch information
Showing
10 changed files
with
340 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
_base_ = ['./base.py', '../_base_/backends/tensorrt.py'] | ||
|
||
backend = 'default' | ||
apply_marks = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
from .anchor_head import AnchorHead | ||
from .rpn_head import rpn_head_forward | ||
from .fsaf_head import fsaf_head_forward | ||
|
||
__all__ = ['AnchorHead'] | ||
__all__ = ['AnchorHead', 'rpn_head_forward', 'fsaf_head_forward'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from mmdeploy.utils import FUNCTION_REWRITERS, mark | ||
|
||
|
||
@FUNCTION_REWRITERS.register_rewriter('mmdet.models.FSAFHead.forward') | ||
@mark('rpn_forward') | ||
def fsaf_head_forward(rewriter, *args): | ||
return rewriter.origin_func(*args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from mmdeploy.utils import FUNCTION_REWRITERS, mark | ||
|
||
|
||
@FUNCTION_REWRITERS.register_rewriter('mmdet.models.RPNHead.forward') | ||
@mark('rpn_forward') | ||
def rpn_head_forward(rewriter, self, feats): | ||
return rewriter.origin_func(self, feats) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .single_stage import SingleStageDetector | ||
from .two_stage import extract_feat | ||
|
||
__all__ = ['SingleStageDetector'] | ||
__all__ = ['SingleStageDetector', 'extract_feat'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from mmdeploy.utils import FUNCTION_REWRITERS, mark | ||
from mmdeploy.utils import SYMBOLICS_REGISTER | ||
from mmcv.onnx.symbolic import grid_sampler | ||
|
||
|
||
@FUNCTION_REWRITERS.register_rewriter('mmdet.models.TwoStageDetector.extract_feat') | ||
@mark('extract_feat') | ||
def extract_feat(rewriter, self, img): | ||
return rewriter.origin_func(self, img) | ||
|
||
|
||
@FUNCTION_REWRITERS.register_rewriter('mmdet.models.TwoStageDetector.forward') | ||
def two_stage_forward(rewriter, self, img, *args): | ||
return rewriter.origin_func(self, [img], img_metas=[[{}]], return_loss=False, *args) | ||
|
||
|
||
@SYMBOLICS_REGISTER.register_symbolic('grid_sampler', is_pytorch=True) | ||
def symbolic_grid_sample(symbolic_wrapper, *args): | ||
return grid_sampler(*args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,9 @@ | ||
from .function_rewriter import FUNCTION_REWRITERS, RewriterContext | ||
from .module_rewriter import MODULE_REWRITERS, patch_model | ||
from .symbolic_register import SYMBOLICS_REGISTER, register_extra_symbolics | ||
from .function_marker import mark | ||
|
||
__all__ = [ | ||
'RewriterContext', 'FUNCTION_REWRITERS', 'MODULE_REWRITERS', 'patch_model', | ||
'SYMBOLICS_REGISTER', 'register_extra_symbolics' | ||
'SYMBOLICS_REGISTER', 'register_extra_symbolics', 'mark' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import inspect | ||
import torch | ||
from .function_rewriter import FUNCTION_REWRITERS | ||
|
||
|
||
class Mark(torch.autograd.Function): | ||
@staticmethod | ||
def symbolic(g, x, type, name, id, attrs): | ||
n = g.op("mmcv::Mark", x, type_s=type, name_s=name, id_i=id, **attrs) | ||
return n | ||
|
||
@staticmethod | ||
def forward(ctx, x, *args): | ||
return x | ||
|
||
|
||
@FUNCTION_REWRITERS.register_rewriter("mmdeploy.utils.function_marker.Mark.symbolic") | ||
def mark_symbolic(rewriter, g, x, *args): | ||
if rewriter.cfg.get("apply_marks", False): | ||
return rewriter.origin_func(g, x, *args) | ||
return x | ||
|
||
|
||
def mark_tensors(xs, type, name, attrs): | ||
index = 0 | ||
visit = set() | ||
|
||
def impl(ys, prefix): | ||
nonlocal index | ||
if isinstance(ys, torch.Tensor): | ||
if ys not in visit: | ||
visit.add(ys) | ||
index += 1 | ||
return Mark.apply(ys, type, prefix, index - 1, attrs) | ||
return ys | ||
elif isinstance(ys, list): | ||
return [impl(y, f'{prefix}/{i}') for i, y in enumerate(ys)] | ||
elif isinstance(ys, tuple): | ||
return tuple(impl(y, f'{prefix}/{i}') for i, y in enumerate(ys)) | ||
elif isinstance(ys, dict): | ||
return {k: impl(v, f'{prefix}/{k}') for k, v in ys.items()} | ||
return ys | ||
return impl(xs, name) | ||
|
||
|
||
def mark(func, **attrs): | ||
attrs['func_s'] = func | ||
|
||
def decorator(f): | ||
params = inspect.signature(f).parameters.keys() | ||
def g(*args, **kwargs): | ||
if torch.onnx.is_in_onnx_export(): | ||
args = [mark_tensors(arg, 'input', name, attrs) | ||
for name, arg in zip(params, args)] | ||
rets = f(*args, **kwargs) | ||
# TODO: maybe we can traverse the AST to get the retval names? | ||
return mark_tensors(rets, 'output', func, attrs) | ||
else: | ||
return f(*args, **kwargs) | ||
return g | ||
return decorator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,234 @@ | ||
import argparse | ||
import os.path as osp | ||
import onnx | ||
import onnx.utils | ||
import onnx.helper | ||
from onnx import AttributeProto | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser( | ||
description='Extract model based on markers.') | ||
parser.add_argument('input_model', help='Input ONNX model') | ||
parser.add_argument('output_model', help='Output ONNX model') | ||
parser.add_argument( | ||
'--start', help='Start markers, format: func:type, e.g. backbone:input') | ||
parser.add_argument('--end', help='End markers') | ||
|
||
args = parser.parse_args() | ||
|
||
args.start = args.start.split(',') if args.start else [] | ||
args.end = args.end.split(',') if args.end else [] | ||
|
||
return args | ||
|
||
|
||
def remove_markers(model): | ||
shortcut = [] | ||
success = True | ||
while success: | ||
success = False | ||
for i, node in enumerate(model.graph.node): | ||
if node.op_type == 'Mark': | ||
for input in node.input: | ||
shortcut.append((input, node.output)) | ||
del model.graph.node[i] | ||
success = True | ||
break | ||
for src, dsts in shortcut: | ||
for curr in model.graph.node: | ||
for k, input in enumerate(curr.input): | ||
if input in dsts: | ||
curr.input[k] = src | ||
# TODO: handle duplicated case? | ||
for k, output in enumerate(model.graph.output): | ||
print(output.name, dsts) | ||
if output.name in dsts: | ||
output.name = src | ||
return model | ||
|
||
|
||
def attribute_to_dict(attribute): | ||
ret = {} | ||
for a in attribute: | ||
name = a.name | ||
if a.type == AttributeProto.AttributeType.STRING: | ||
ret[name] = str(a.s, 'utf-8') | ||
elif a.type == AttributeProto.AttributeType.INT: | ||
ret[name] = a.i | ||
return ret | ||
|
||
|
||
def _dfs_search_reacable_nodes_fast(self, node_output_name, graph_input_nodes, reachable_nodes): | ||
outputs = {} | ||
for index, node in enumerate(self.graph.node): | ||
for name in node.output: | ||
if name not in outputs: | ||
outputs[name] = set() | ||
outputs[name].add(index) | ||
|
||
def impl(node_output_name, graph_input_nodes, reachable_nodes): | ||
if node_output_name in graph_input_nodes: | ||
return | ||
if node_output_name not in outputs: | ||
return | ||
for index in outputs[node_output_name]: | ||
node = self.graph.node[index] | ||
if node in reachable_nodes: | ||
continue | ||
reachable_nodes.append(node) | ||
for name in node.input: | ||
impl(name, graph_input_nodes, reachable_nodes) | ||
impl(node_output_name, graph_input_nodes, reachable_nodes) | ||
|
||
|
||
def get_new_name(attrs): | ||
if 'name' in attrs: | ||
return attrs['name'] | ||
return '_'.join((attrs['func'], attrs['type'], str(attrs['id']))) | ||
|
||
|
||
def rename_value(model, old_name, new_name): | ||
for n in model.graph.node: | ||
for i, output in enumerate(n.output): | ||
if output == old_name: | ||
n.output[i] = new_name | ||
for i, input in enumerate(n.input): | ||
if input == old_name: | ||
n.input[i] = new_name | ||
for v in model.graph.value_info: | ||
if v.name == old_name: | ||
v.name = new_name | ||
for i, name in enumerate(model.graph.input): | ||
if name == old_name: | ||
model.graph.input[i] = new_name | ||
for i, name in enumerate(model.graph.output): | ||
if name == old_name: | ||
model.graph.output[i] = new_name | ||
|
||
|
||
def extract_model(model, start, end): | ||
inputs = [] | ||
outputs = [] | ||
if not isinstance(start, (list, tuple)): | ||
start = [start] | ||
for s in start: | ||
start_name, start_type = s.split(':') | ||
assert start_type in ['input', 'output'] | ||
for node in model.graph.node: | ||
if node.op_type == 'Mark': | ||
attr = attribute_to_dict(node.attribute) | ||
if attr['func'] == start_name and attr['type'] == start_type: | ||
name = node.output[0] if start_type == 'input' else node.input[0] | ||
if name not in inputs: | ||
new_name = get_new_name(attr) | ||
rename_value(model, name, new_name) | ||
inputs.append(new_name) | ||
|
||
print(f'inputs: {inputs}') | ||
|
||
# collect outputs | ||
# outputs = [] | ||
if not isinstance(end, (list, tuple)): | ||
end = [end] | ||
for e in end: | ||
end_name, end_type = e.split(':') | ||
assert end_type in ['input', 'output'] | ||
for node in model.graph.node: | ||
if node.op_type == 'Mark': | ||
attr = attribute_to_dict(node.attribute) | ||
if attr['func'] == end_name and attr['type'] == end_type: | ||
name = node.input[0] if end_type == 'input' else node.output[0] | ||
if name not in outputs: | ||
new_name = get_new_name(attr) | ||
rename_value(model, name, new_name) | ||
outputs.append(new_name) | ||
|
||
print(f'outputs: {outputs}') | ||
|
||
# replace Mark with Identity | ||
for node in model.graph.node: | ||
if node.op_type == 'Mark': | ||
del node.attribute[:] | ||
node.domain = '' | ||
node.op_type = 'Identity' | ||
|
||
# patch extractor | ||
onnx.utils.Extractor._dfs_search_reachable_nodes = _dfs_search_reacable_nodes_fast | ||
|
||
extractor = onnx.utils.Extractor(model) | ||
extracted_model = extractor.extract_model(inputs, outputs) | ||
|
||
# collect all used inputs | ||
used = set() | ||
for node in extracted_model.graph.node: | ||
for input in node.input: | ||
used.add(input) | ||
|
||
for output in extracted_model.graph.output: | ||
used.add(output.name) | ||
|
||
# delete unused inputs | ||
success = True | ||
while success: | ||
success = False | ||
for i, input in enumerate(extracted_model.graph.input): | ||
if input.name not in used: | ||
del extracted_model.graph.input[i] | ||
success = True | ||
break | ||
|
||
# eliminate output without shape | ||
for xs in [extracted_model.graph.output]: | ||
for x in xs: | ||
if not x.type.tensor_type.shape.dim: | ||
print(f'fixing output shape: {x.name}') | ||
x.CopyFrom(onnx.helper.make_tensor_value_info( | ||
x.name, x.type.tensor_type.elem_type, [])) | ||
|
||
# eliminate 0-batch dimension, dirty workaround for two-stage detectors | ||
for input in extracted_model.graph.input: | ||
if input.name in inputs: | ||
if input.type.tensor_type.shape.dim[0].dim_value == 0: | ||
input.type.tensor_type.shape.dim[0].dim_value = 1 | ||
|
||
# eliminate duplicated value_info for inputs | ||
success = True | ||
while success: | ||
success = False | ||
for i, x in enumerate(extracted_model.graph.value_info): | ||
if x.name in inputs: | ||
del extracted_model.graph.value_info[i] | ||
success = True | ||
break | ||
|
||
return extracted_model | ||
|
||
|
||
def collect_avaiable_marks(model): | ||
marks = [] | ||
for node in model.graph.node: | ||
if node.op_type == 'Mark': | ||
attr = attribute_to_dict(node.attribute) | ||
func = attr['func'] | ||
if func not in marks: | ||
marks.append(func) | ||
return marks | ||
|
||
|
||
def main(): | ||
args = parse_args() | ||
|
||
model = onnx.load(args.input_model) | ||
marks = collect_avaiable_marks(model) | ||
print("Available marks:\n {}".format('\n '.join(marks))) | ||
|
||
extracted_model = extract_model(model, args.start, args.end) | ||
|
||
if osp.splitext(args.output_model)[-1] != '.onnx': | ||
args.output_model += '.onnx' | ||
onnx.save(extracted_model, args.output_model) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |