-
Notifications
You must be signed in to change notification settings - Fork 6
/
get_flops.py
132 lines (110 loc) · 4.41 KB
/
get_flops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# Copyright (c) Open-CD. All rights reserved.
import argparse
import tempfile
from pathlib import Path
import torch
from mmengine import Config, DictAction
from mmengine.logging import MMLogger
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmseg.models import BaseSegmentor
from mmseg.registry import MODELS
from mmseg.structures import SegDataSample
import opencd_custom # noqa: F401,F403
try:
from mmengine.analysis import get_model_complexity_info
from mmengine.analysis.print_helper import _format_size
except ImportError:
raise ImportError('Please upgrade mmengine >= 0.6.0 to use this script.')
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: ``config`` (path of the config file to profile),
        ``shape`` (list of ints, default ``[512, 512]``) and ``cfg_options``
        (overrides collected by mmengine's ``DictAction``; ``None`` when the
        flag is not given).
    """
    # Help text kept as a local so the add_argument call stays readable; the
    # concatenated string is identical to the one shown by ``--help`` before.
    cfg_options_help = (
        'override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')

    parser = argparse.ArgumentParser(
        description='Get the FLOPs of a segmentor')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--shape',
        nargs='+',
        type=int,
        default=[512, 512],
        help='input image size')
    parser.add_argument(
        '--cfg-options',
        action=DictAction,
        nargs='+',
        help=cfg_options_help)
    return parser.parse_args()
# decode_head types the FLOPs counter cannot handle yet.
_UNSUPPORTED_DECODE_HEADS = ('MaskFormerHead', 'Mask2FormerHead')


def _resolve_input_shape(shape, in_channels=6):
    """Turn the ``--shape`` CLI value into a ``(C, H, W)`` tuple.

    Args:
        shape: One int (square input) or two ints ``(H, W)``.
        in_channels: Channel count of the dummy input. The default of 6 is
            presumably a pair of 3-channel images stacked on the channel axis
            (change-detection input) -- confirm against the configs profiled.

    Raises:
        ValueError: If ``shape`` does not contain one or two values.
    """
    if len(shape) == 1:
        return (in_channels, shape[0], shape[0])
    if len(shape) == 2:
        return (in_channels, ) + tuple(shape)
    raise ValueError('invalid input shape')


def _check_decode_head_supported(model_cfg):
    """Raise early if the config uses a decode head we cannot profile.

    ``decode_head`` may be absent, a single dict, or a list of dicts; a
    missing head is not an error here (the model build decides that).
    """
    decode_head = model_cfg.get('decode_head')
    heads = decode_head if isinstance(decode_head, (list, tuple)) \
        else [decode_head]
    for head in heads:
        if isinstance(head, dict) and \
                head.get('type') in _UNSUPPORTED_DECODE_HEADS:
            # TODO: Support MaskFormer and Mask2Former
            raise NotImplementedError('MaskFormer and Mask2Former are not '
                                      'supported yet.')


def inference(args: argparse.Namespace, logger: MMLogger) -> dict:
    """Build the model from ``args.config`` and measure its complexity.

    A random tensor of the requested shape is pushed through the model's
    data preprocessor and then analysed with mmengine's
    ``get_model_complexity_info``.

    Args:
        args: Parsed CLI arguments (``config``, ``shape``, ``cfg_options``).
        logger: Logger used to report a missing config file.

    Returns:
        dict: ``ori_shape`` / ``pad_shape`` (spatial size of the dummy input),
        ``flops``, ``params`` and ``learnable_params`` (human-readable
        strings) and ``compute_type`` (how the input was produced).

    Raises:
        FileNotFoundError: If the config file does not exist.
        ValueError: If ``args.shape`` has neither one nor two values.
        NotImplementedError: For MaskFormer / Mask2Former decode heads.
    """
    config_name = Path(args.config)
    if not config_name.exists():
        # Previously this only logged and carried on into Config.fromfile;
        # stop here with a clear error instead.
        logger.error(f'Config file {config_name} does not exist')
        raise FileNotFoundError(f'Config file {config_name} does not exist')

    cfg: Config = Config.fromfile(config_name)
    # Only a placeholder path is needed: the TemporaryDirectory object is not
    # kept, so the directory itself may already be removed. Nothing in this
    # function writes to work_dir.
    cfg.work_dir = tempfile.TemporaryDirectory().name
    cfg.log_level = 'WARN'
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    init_default_scope(cfg.get('scope', 'opencd'))

    # Validate everything that can be checked on the config before paying
    # for the model build.
    input_shape = _resolve_input_shape(args.shape)
    _check_decode_head_supported(cfg.model)

    result = {}
    model: BaseSegmentor = MODELS.build(cfg.model)
    # The auxiliary head is dropped so that it is not included in the count.
    if hasattr(model, 'auxiliary_head'):
        model.auxiliary_head = None
    if torch.cuda.is_available():
        model.cuda()
    model = revert_sync_batchnorm(model)

    result['ori_shape'] = input_shape[-2:]
    result['pad_shape'] = input_shape[-2:]
    data_batch = {
        'inputs': [torch.rand(input_shape)],
        'data_samples': [SegDataSample(metainfo=result)]
    }
    data = model.data_preprocessor(data_batch)
    model.eval()

    outputs = get_model_complexity_info(
        model,
        input_shape=None,
        inputs=data['inputs'],
        show_table=False,
        show_arch=False)
    learnable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad)

    result['flops'] = _format_size(outputs['flops'], 6)
    result['params'] = _format_size(outputs['params'], 6)
    result['learnable_params'] = _format_size(learnable_params, 6)
    result['compute_type'] = 'direct: randomly generate a picture'
    return result
def main():
    """Command-line entry point: profile the configured model and print a report."""
    args = parse_args()
    logger = MMLogger.get_instance(name='MMLogger')
    result = inference(args, logger)

    divider = '=' * 30
    # Pull every field out up front so a missing key fails before any output.
    (ori_shape, pad_shape, flops, params, learnable_params,
     compute_type) = (result[key] for key in (
         'ori_shape', 'pad_shape', 'flops', 'params', 'learnable_params',
         'compute_type'))

    if ori_shape != pad_shape:
        print(f'{divider}\nUse size divisor set input shape '
              f'from {ori_shape} to {pad_shape}')

    report = (f'{divider}\n'
              f'Compute type: {compute_type}\n'
              f'Input shape: {pad_shape}\n'
              f'Flops: {flops}\n'
              f'Params: {params}\n'
              f'Learnable Params: {learnable_params}\n'
              f'{divider} \n')
    print(report)

    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify '
          'that the flops computation is correct.')
# Run the FLOPs report only when executed as a script, not when imported.
if __name__ == '__main__':
    main()