forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 1
/
build_vit_qformer.py
90 lines (71 loc) · 2.73 KB
/
build_vit_qformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import sys
from time import time
import tensorrt as trt
# Which model part to build, taken from the first command-line argument:
# 0 = visual encoder (ViT), 1 = Q-Former. A missing or non-numeric argument
# becomes -1 (not a valid part). isdecimal() is used rather than isdigit()
# because isdigit() also accepts characters such as '²' that int() rejects.
iModelID = (int(sys.argv[1])
            if len(sys.argv) > 1 and sys.argv[1].isdecimal() else -1)

# ONNX inputs and FP16 engine outputs, both indexed by model part.
onnxFileList = [
    'onnx/visual_encoder/visual_encoder.onnx',
    'onnx/Qformer/Qformer.onnx',
]
planFileList = [
    'plan/visual_encoder/visual_encoder_fp16.plan',
    'plan/Qformer/Qformer_fp16.plan',
]

# Ensure every engine output directory exists (relative to the current
# working directory). This is derived from planFileList so the two cannot
# drift apart, and it avoids spawning a shell for `mkdir -p`.
for _planFile in planFileList:
    os.makedirs(os.path.dirname(_planFile), exist_ok=True)
del _planFile
# Shared TensorRT logger; only messages of ERROR severity or above are shown.
logger = trt.Logger(trt.Logger.ERROR)

# Per-part network input shapes with the batch dimension omitted.
# Entry i of each list is applied to network input i.
# Part 0 is the visual encoder (ViT); part 1 is the Q-Former.
_INPUT_SHAPES = {
    0: [(3, 224, 224)],
    1: [(32, 768), (257, 1408), (257,)],
}


def build(iPart, minBS=1, optBS=2, maxBS=4):
    """Build and serialize an FP16 TensorRT engine for one model part.

    Parses ``onnxFileList[iPart]`` and makes the leading (batch) dimension of
    each network input dynamic. It then attaches one optimization profile
    spanning ``minBS``/``optBS``/``maxBS``, builds with the FP16 builder flag,
    and writes the serialized engine to ``planFileList[iPart]``.

    Args:
        iPart: 0 for the visual encoder (ViT), 1 for the Q-Former.
        minBS: Smallest batch size in the optimization profile.
        optBS: Batch size TensorRT optimizes for.
        maxBS: Largest batch size in the optimization profile.

    Raises:
        RuntimeError: If ``iPart`` is not 0 or 1.

    ONNX parse failures and engine build failures are reported on stdout.
    In either case the function returns without writing a plan file.
    """
    # Validate before touching any file. Previously an out-of-range part
    # raised IndexError, or parsed the wrong model before failing.
    if iPart not in _INPUT_SHAPES:
        raise RuntimeError("iPart should be either 0 (ViT) or 1 (Qformer)")
    onnxFile = onnxFileList[iPart]
    planFile = planFileList[iPart]

    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    profile = builder.create_optimization_profile()
    config = builder.create_builder_config()
    config.set_flag(trt.BuilderFlag.FP16)
    parser = trt.OnnxParser(network, logger)

    with open(onnxFile, 'rb') as model:
        # The path argument lets the parser resolve externally stored weights.
        parsed = parser.parse(model.read(), onnxFile)
    if not parsed:
        print("Failed parsing %s" % onnxFile)
        for errorIndex in range(parser.num_errors):
            print(parser.get_error(errorIndex))
        # Do not go on to build an engine from a partially parsed network.
        return
    print("Succeeded parsing %s" % onnxFile)

    for inputIndex, dims in enumerate(_INPUT_SHAPES[iPart]):
        inputT = network.get_input(inputIndex)
        # -1 marks the batch dimension as dynamic; the profile bounds it.
        inputT.shape = [-1, *dims]
        profile.set_shape(inputT.name, [minBS, *dims], [optBS, *dims],
                          [maxBS, *dims])
    config.add_optimization_profile(profile)

    t0 = time()
    engineString = builder.build_serialized_network(network, config)
    t1 = time()
    if engineString is None:
        print("Failed building %s" % planFile)
        return
    print("Succeeded building %s in %d s" % (planFile, t1 - t0))
    with open(planFile, 'wb') as f:
        f.write(engineString)
if __name__ == "__main__":
    # Only parts 0 (ViT) and 1 (Q-Former) exist. Reject anything else with a
    # non-zero exit status so shell scripts and CI notice the failure.
    # The bare site-provided exit() used before exited with status 0.
    if iModelID not in (0, 1):
        print("Error model number, should be in [0, 1]", file=sys.stderr)
        sys.exit(1)
    build(iModelID, minBS=1, optBS=2, maxBS=4)