forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path build_visual_engine.py
208 lines (168 loc) · 7.76 KB
/
build_visual_engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import argparse
import os
import shutil
from time import time
import tensorrt as trt
import torch
from PIL import Image
from transformers import (AutoProcessor, Blip2ForConditionalGeneration,
Blip2Processor, LlavaForConditionalGeneration,
NougatProcessor, VisionEncoderDecoderModel)
def export_visual_wrapper_onnx(visual_wrapper, image, output_dir):
    """Export a vision wrapper module to ONNX ahead of the TensorRT build.

    The graph is written to ``{output_dir}/onnx/visual_encoder.onnx``, the
    path that ``build_trt_engine`` reads back.

    Args:
        visual_wrapper: ``torch.nn.Module`` whose ``forward`` takes the image
            tensor and returns the visual features.
        image: example input used to trace the module; its first axis is
            exported as the dynamic ``batch`` dimension.
        output_dir: directory in which the ``onnx`` subdirectory is created.

    Note:
        Logs through a module-level ``logger`` (a ``trt.Logger``), which must
        exist before this function is called.
    """
    logger.log(trt.Logger.INFO, "Exporting onnx")
    onnx_dir = f'{output_dir}/onnx'
    # exist_ok: a previous run that failed before build_trt_engine removed this
    # directory would otherwise make os.mkdir raise FileExistsError.
    os.makedirs(onnx_dir, exist_ok=True)
    torch.onnx.export(visual_wrapper,
                      image,
                      f'{onnx_dir}/visual_encoder.onnx',
                      opset_version=17,
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes={'input': {
                          0: 'batch'
                      }})
def build_trt_engine(img_height, img_width, output_dir, max_batch_size):
    """Build an FP16 TensorRT engine from the exported visual-encoder ONNX.

    Reads ``{output_dir}/onnx/visual_encoder.onnx``, deletes the ``onnx``
    directory once the graph has been parsed, and writes the serialized
    engine to ``{output_dir}/visual_encoder_fp16.engine``.

    Args:
        img_height: height of the preprocessed image fed to the encoder.
        img_width: width of the preprocessed image fed to the encoder.
        output_dir: directory holding the ``onnx`` subdirectory and receiving
            the engine file.
        max_batch_size: largest batch the optimization profile accepts;
            must be >= 1.

    Raises:
        ValueError: if ``max_batch_size`` is less than 1.
        RuntimeError: if the ONNX file cannot be parsed or the engine build
            fails.

    Note:
        Logs through a module-level ``logger`` (a ``trt.Logger``), which must
        exist before this function is called.
    """
    if max_batch_size < 1:
        raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")

    part_name = 'visual_encoder'
    onnx_dir = f'{output_dir}/onnx'
    onnx_file = f'{onnx_dir}/{part_name}.onnx'
    engine_file = f'{output_dir}/{part_name}_fp16.engine'
    logger.log(trt.Logger.INFO, "Building TRT engine for %s" % part_name)

    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    profile = builder.create_optimization_profile()
    config = builder.create_builder_config()
    config.set_flag(trt.BuilderFlag.FP16)

    parser = trt.OnnxParser(network, logger)
    with open(onnx_file, 'rb') as model:
        # TensorRT uses the path to locate externally stored weights, if any.
        parsed = parser.parse(model.read(), onnx_file)
    if not parsed:
        logger.log(trt.Logger.ERROR, "Failed parsing %s" % onnx_file)
        for idx in range(parser.num_errors):
            # get_error() returns a ParserError object; Logger.log takes a str.
            logger.log(trt.Logger.ERROR, str(parser.get_error(idx)))
        # Stop here rather than building from an incomplete network; the ONNX
        # file is kept on disk for debugging.
        raise RuntimeError("Failed parsing %s" % onnx_file)
    logger.log(trt.Logger.INFO, "Succeeded parsing %s" % onnx_file)

    # The ONNX graph is no longer needed once it has been parsed.
    shutil.rmtree(onnx_dir)

    logger.log(trt.Logger.INFO,
               f"Processed image dims {img_height}x{img_width}")
    min_bs = 1
    opt_bs = max(min_bs, max_batch_size // 2)
    max_bs = max_batch_size
    height, width = img_height, img_width

    # Batch axis stays dynamic (-1); channels are fixed at 3 and height/width
    # at the processed image size. The profile bounds the batch range.
    input_tensor = network.get_input(0)
    input_tensor.shape = [-1, 3, height, width]
    profile.set_shape(input_tensor.name, [min_bs, 3, height, width],
                      [opt_bs, 3, height, width], [max_bs, 3, height, width])
    config.add_optimization_profile(profile)

    start = time()
    engine_string = builder.build_serialized_network(network, config)
    elapsed = time() - start
    if engine_string is None:
        raise RuntimeError("Failed building %s" % (engine_file))
    logger.log(trt.Logger.INFO,
               "Succeeded building %s in %d s" % (engine_file, elapsed))
    with open(engine_file, 'wb') as f:
        f.write(engine_string)
def build_blip2_engine(args):
    """Export the BLIP-2 vision path and build its TensorRT engine.

    The exported module chains the BLIP-2 vision model, the Q-Former (queried
    by the model's learned query tokens) and the language projection, so the
    engine maps pixel values to the projector's output.

    Args:
        args: parsed CLI namespace; reads ``model_name`` (appended to
            ``'Salesforce/blip2-'`` to form the Hugging Face repo id),
            ``device``, ``output_dir`` and ``max_batch_size``.
    """
    model_type = 'Salesforce/blip2-' + args.model_name
    processor = Blip2Processor.from_pretrained(model_type)
    model = Blip2ForConditionalGeneration.from_pretrained(
        model_type, torch_dtype=torch.float16)
    model.to(args.device)

    # Dummy image: serves only as the tracing example and to obtain the
    # processor's output height/width.
    raw_image = Image.new('RGB', [10, 10])
    prompt = "Question: what is this? Answer:"
    inputs = processor(raw_image, prompt,
                       return_tensors="pt").to(args.device, torch.float16)
    image = inputs['pixel_values']

    class Blip2VisionWrapper(torch.nn.Module):
        """Vision model -> Q-Former -> language projection, as one module."""

        def __init__(self, model):
            super().__init__()
            self.vision_model = model.vision_model
            self.qformer = model.qformer
            self.projector = model.language_projection
            self.query_tokens = model.query_tokens

        def forward(self, image):
            features = self.vision_model(image)[0]
            # Hugging Face's own forward expands query_tokens to the image
            # batch before calling the Q-Former. Doing the same makes the
            # exported query batch follow the dynamic 'batch' axis rather
            # than the parameter's stored batch size.
            query_tokens = self.query_tokens.expand(features.shape[0], -1, -1)
            qformer_output = self.qformer(query_embeds=query_tokens,
                                          encoder_hidden_states=features,
                                          return_dict=True)
            return self.projector(qformer_output.last_hidden_state)

    wrapper = Blip2VisionWrapper(model)
    export_visual_wrapper_onnx(wrapper, image, args.output_dir)
    build_trt_engine(image.shape[2], image.shape[3], args.output_dir,
                     args.max_batch_size)
def build_llava_engine(args):
    """Export the LLaVA vision tower + projector and build its TRT engine.

    Args:
        args: parsed CLI namespace; reads ``model_path`` (Hugging Face repo or
            local directory), ``device``, ``output_dir`` and
            ``max_batch_size``.

    Raises:
        ValueError: if the model config names a vision feature-select
            strategy other than ``'default'`` or ``'full'``.
    """
    processor = AutoProcessor.from_pretrained(args.model_path)
    raw_image = Image.new('RGB', [10, 10])  # dummy image
    image = processor(text="dummy", images=raw_image,
                      return_tensors="pt")['pixel_values'].to(
                          args.device, torch.float16)

    class LlavaVisionWrapper(torch.nn.Module):
        """Pick one vision-tower hidden layer and run it through the projector."""

        def __init__(self, tower, projector, feature_layer,
                     select_strategy='default'):
            super().__init__()
            self.tower = tower
            self.projector = projector
            self.feature_layer = feature_layer
            self.select_strategy = select_strategy

        def forward(self, image):
            all_hidden_states = self.tower(
                image, output_hidden_states=True).hidden_states
            features = all_hidden_states[self.feature_layer]
            if self.select_strategy == 'default':
                # Drop the first token (the class token for CLIP-style towers)
                # and keep only the patch tokens.
                features = features[:, 1:]
            return self.projector(features)

    model = LlavaForConditionalGeneration.from_pretrained(
        args.model_path, torch_dtype=torch.float16)
    model.to(args.device)

    # Follow the config's selection strategy as Hugging Face's Llava does:
    # 'default' drops the first token, 'full' keeps every token. Configs
    # without the field fall back to 'default'.
    select_strategy = getattr(model.config, 'vision_feature_select_strategy',
                              'default')
    if select_strategy not in ('default', 'full'):
        raise ValueError(
            f"Unexpected vision_feature_select_strategy {select_strategy!r}")

    wrapper = LlavaVisionWrapper(model.vision_tower,
                                 model.multi_modal_projector,
                                 model.config.vision_feature_layer,
                                 select_strategy)
    export_visual_wrapper_onnx(wrapper, image, args.output_dir)
    build_trt_engine(image.shape[2], image.shape[3], args.output_dir,
                     args.max_batch_size)
def build_nougat_engine(args):
    """Export Nougat's encoder and build its TensorRT engine.

    Only the encoder half of the ``VisionEncoderDecoderModel`` is exported;
    the engine returns the encoder's ``last_hidden_state``.

    Args:
        args: parsed CLI namespace; reads ``model_path``, ``device``,
            ``output_dir`` and ``max_batch_size``.
    """

    class SwinEncoderWrapper(torch.nn.Module):
        """Expose only ``last_hidden_state`` from the wrapped encoder."""

        def __init__(self, encoder):
            super().__init__()
            self.encoder = encoder

        def forward(self, image):
            encoder_out = self.encoder(image)
            return encoder_out.last_hidden_state

    nougat_processor = NougatProcessor.from_pretrained(args.model_path)
    # Placeholder image: used as the tracing example and to read the
    # processed height/width.
    placeholder = Image.new('RGB', [10, 10])
    pixel_values = nougat_processor(placeholder,
                                    return_tensors="pt")['pixel_values']
    pixel_values = pixel_values.to(args.device, torch.float16)

    full_model = VisionEncoderDecoderModel.from_pretrained(
        args.model_path, torch_dtype=torch.float16)
    encoder = full_model.get_encoder().to(args.device)

    export_visual_wrapper_onnx(SwinEncoderWrapper(encoder), pixel_values,
                               args.output_dir)
    height, width = pixel_values.shape[2], pixel_values.shape[3]
    build_trt_engine(height, width, args.output_dir, args.max_batch_size)
# The helper functions above log through this module-level logger; creating it
# outside the __main__ guard keeps them usable when the module is imported.
logger = trt.Logger(trt.Logger.ERROR)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name',
                        type=str,
                        required=True,
                        help="Model name")
    parser.add_argument('--model_path',
                        type=str,
                        default=None,
                        help="Huggingface repo or local directory with weights")
    parser.add_argument('--output_dir',
                        type=str,
                        default='visual_engines',
                        help="Directory where visual TRT engines are saved")
    parser.add_argument('--max_batch_size',
                        type=int,
                        default=4,
                        help="Maximum batch size for input images")
    args = parser.parse_args()

    # Choose the builder before touching the filesystem so an unsupported
    # name does not leave an empty output directory behind.
    if args.model_name in ['opt-2.7b', 'flan-t5-xl']:
        build_engine = build_blip2_engine
    elif 'llava' in args.model_name:
        build_engine = build_llava_engine
    elif 'nougat' in args.model_name:
        build_engine = build_nougat_engine
    else:
        raise RuntimeError(f"Invalid model name {args.model_name}")

    # BLIP-2 derives its repo id from --model_name; the other builders load
    # weights from --model_path and cannot proceed without it.
    if build_engine is not build_blip2_engine and args.model_path is None:
        parser.error(f"--model_path is required for {args.model_name}")
    if args.max_batch_size < 1:
        parser.error("--max_batch_size must be >= 1")

    args.device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
    args.output_dir = os.path.join(args.output_dir, args.model_name)
    os.makedirs(args.output_dir, exist_ok=True)
    build_engine(args)