forked from OpenGVLab/InternVL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
export.py
128 lines (94 loc) · 3.29 KB
/
export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# --------------------------------------------------------
# InternVL
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import argparse
import os
import time
import torch
from config import get_config
from models import build_model
from tqdm import tqdm
def get_args():
    """Parse the command line and derive the paths and input size.

    Besides the parsed options, three attributes are attached to the
    returned namespace:

    * ``cfg``  -- ``./configs/<model_name>.yaml``
    * ``ckpt`` -- ``<ckpt_dir>/<model_name>.pth``
    * ``size`` -- the integer after the last ``_`` of the model name
      (anything following the first ``.`` is ignored), e.g. ``224`` for
      ``internimage_t_1k_224``.

    Returns:
        A ``(args, cfg)`` tuple, where ``cfg`` is whatever
        ``get_config(args)`` returns.

    Raises:
        ValueError: if the trailing ``_``-separated token of the model
            name is not an integer.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--model_name',
        type=str,
        default='internimage_t_1k_224',
    )
    cli.add_argument(
        '--ckpt_dir',
        type=str,
        default='/mnt/petrelfs/share_data/huangzhenhang/code/internimage/checkpoint_dir/new/cls',
    )
    # Both export targets are plain opt-in switches.
    for flag in ('--onnx', '--trt'):
        cli.add_argument(flag, default=False, action='store_true')
    args = cli.parse_args()

    name = args.model_name
    args.cfg = os.path.join('./configs', name + '.yaml')
    args.ckpt = os.path.join(args.ckpt_dir, name + '.pth')

    # The resolution is encoded as the last underscore-separated token
    # of the part of the name that precedes the first dot.
    stem = name.partition('.')[0]
    args.size = int(stem.rsplit('_', 1)[-1])

    return args, get_config(args)
def get_model(args, cfg):
    """Build the model described by ``cfg`` and load its weights.

    The checkpoint at ``args.ckpt`` is deserialized onto the CPU and its
    ``'model'`` entry is loaded as the state dict.

    Returns:
        The model with the checkpoint weights loaded (not moved to any
        device and not switched to eval mode here).
    """
    net = build_model(cfg)
    checkpoint = torch.load(args.ckpt, map_location='cpu')
    state_dict = checkpoint['model']
    net.load_state_dict(state_dict)
    return net
def speed_test(model, input, warmup=100, iters=100):
    """Measure and print the forward-pass throughput of ``model``.

    Args:
        model: Callable invoked as ``model(input)``.
        input: The single argument passed to ``model`` on every call.
        warmup: Number of untimed calls made before measuring.
        iters: Number of timed calls.

    The elapsed wall-clock time and the throughput (calls per second)
    are printed; nothing is returned.
    """
    # Benchmarking inference only: skip autograd bookkeeping so it does
    # not add overhead or retain activations.
    with torch.no_grad():
        # warm-up
        for _ in tqdm(range(warmup)):
            model(input)

        # Make sure the warm-up work has finished before the clock starts.
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in tqdm(range(iters)):
            model(input)
        # CUDA kernels are launched asynchronously; without this barrier
        # the timer would stop while work is still queued on the GPU and
        # the reported throughput would be too high.
        torch.cuda.synchronize()
        end = time.perf_counter()

    elapsed = end - start
    th = iters / elapsed
    print(f'using time: {elapsed}, throughput {th}')
def torch2onnx(args, cfg):
    """Export the checkpointed model to ``<model_name>.onnx``.

    The model is moved to the GPU and traced with a random
    ``(1, 3, size, size)`` CUDA tensor. The graph's input and output are
    named ``'input'`` and ``'output'``.

    Returns:
        The CUDA model that was exported.
    """
    net = get_model(args, cfg).cuda()
    target_path = f'{args.model_name}.onnx'
    dummy_input = torch.rand(1, 3, args.size, args.size).cuda()
    torch.onnx.export(
        net,
        dummy_input,
        target_path,
        input_names=['input'],
        output_names=['output'],
    )
    return net
def onnx2trt(args):
    """Convert ``<model_name>.onnx`` with mmdeploy's ``from_onnx``.

    The min/opt/max shapes of the ``'input'`` binding are all the same
    ``(1, 3, size, size)`` shape, and the workspace limit is 1 GiB.
    ``args.model_name`` is passed as the second argument -- presumably
    the output prefix, since ``check`` later opens
    ``<model_name>.engine``; confirm against mmdeploy.
    """
    # Imported lazily so mmdeploy is only required when --trt is used.
    from mmdeploy.backend.tensorrt import from_onnx

    def fixed_shape():
        # A fresh list per entry, as each shape is passed separately.
        return [1, 3, args.size, args.size]

    input_shapes = {
        'input': {
            'min_shape': fixed_shape(),
            'opt_shape': fixed_shape(),
            'max_shape': fixed_shape(),
        },
    }
    source_path = f'{args.model_name}.onnx'
    from_onnx(
        source_path,
        args.model_name,
        input_shapes,
        max_workspace_size=1 << 30,
    )
def check(args, cfg):
    """Compare the PyTorch model with the TensorRT engine.

    Both are run on one random ``(1, 3, size, size)`` CUDA tensor. The
    output shapes and the max / mean absolute difference are printed,
    then ``speed_test`` is run on each backend.
    """
    # Imported lazily so mmdeploy is only required when --trt is used.
    from mmdeploy.backend.tensorrt.wrapper import TRTWrapper

    torch_model = get_model(args, cfg).cuda()
    torch_model.eval()
    engine = TRTWrapper(f'{args.model_name}.engine', ['output'])

    sample = torch.randn(1, 3, args.size, args.size).cuda()
    reference = torch_model(sample)
    produced = engine({'input': sample})['output']

    print('torch out shape:', reference.shape)
    print('trt out shape:', produced.shape)

    abs_diff = (reference - produced).abs()
    print('max delta:', abs_diff.max())
    print('mean delta:', abs_diff.mean())

    speed_test(torch_model, sample)
    speed_test(engine, {'input': sample})
def main():
    """Run the export steps selected on the command line.

    * ``--onnx``: export the model to ONNX.
    * ``--trt``: export to ONNX, convert that file with ``onnx2trt``,
      then run ``check`` to compare the result with the PyTorch model.

    With neither flag set, nothing is exported.
    """
    args, cfg = get_args()

    # The TensorRT conversion reads the ONNX file, so --trt implies the
    # ONNX export as well.
    if args.onnx or args.trt:
        torch2onnx(args, cfg)
        print('torch -> onnx: success')

    if args.trt:
        onnx2trt(args)
        print('onnx -> trt: success')
        # Only meaningful here: check() loads the engine built just above.
        check(args, cfg)


if __name__ == '__main__':
    main()