-
Notifications
You must be signed in to change notification settings - Fork 34
/
build.py
61 lines (47 loc) · 1.95 KB
/
build.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import imp
import os
import subprocess
from string import Template
# Names of the converter plugins to compile; each one is expected to live in
# its own sub-directory of BASE_FOLDER.
PLUGINS = ['interpolate']

# Directory (relative to the working directory) holding the plugin sources.
BASE_FOLDER = 'torch2trt_dynamic/converters'

# Ninja rule definitions shared by every plugin.  ``$$`` is the
# string.Template escape for a literal ``$``, so ``$$in`` / ``$$out`` reach
# the generated file as ``$in`` / ``$out``, while ``$torch_dir``,
# ``$cuda_dir``, ``$trt_inc_dir`` and ``$trt_lib_dir`` are substituted here.
# The leading space on each ``command`` line indents it under its rule.
NINJA_TEMPLATE = Template(
    'rule link\n'
    ' command = g++ -shared -o $$out $$in'
    ' -L$torch_dir/lib -L$cuda_dir/lib64 -L$trt_lib_dir'
    ' -lc10 -lc10_cuda -ltorch -lcudart -lprotobuf -lprotobuf-lite'
    ' -pthread -lpthread -lnvinfer\n'
    'rule protoc\n'
    ' command = protoc $$in --cpp_out=. --python_out=.\n'
    'rule cxx\n'
    ' command = g++ -c -fPIC $$in'
    ' -I$cuda_dir/include -I$torch_dir/include'
    ' -I$torch_dir/include/torch/csrc/api/include'
    ' -I. -std=c++11 -I$trt_inc_dir\n'
)

# Per-plugin build statements: generate the protobuf sources from the
# plugin's .proto file, then compile the generated .pb.cc and the plugin's
# own .cpp into object files.
PLUGIN_TEMPLATE = Template(
    'build $plugin_dir/$plugin.pb.h $plugin_dir/$plugin.pb.cc'
    ' $plugin_dir/${plugin}_pb2.py: protoc $plugin_dir/$plugin.proto\n'
    'build $plugin.pb.o: cxx $plugin_dir/$plugin.pb.cc\n'
    'build $plugin.o: cxx $plugin_dir/$plugin.cpp\n'
)
def build(cuda_dir='/usr/local/cuda',
          torch_dir=None,
          trt_inc_dir='/usr/include/aarch64-linux-gnu',
          trt_lib_dir='/usr/lib/aarch64-linux-gnu'):
    """Write ``build.ninja`` for every plugin in PLUGINS and run ``ninja``.

    The file is written to the current working directory and ``ninja`` is
    started there without arguments.

    Args:
        cuda_dir: CUDA root; its ``include`` and ``lib64`` sub-directories
            are passed to the compiler and linker.
        torch_dir: Directory of the installed ``torch`` package; its
            ``include`` and ``lib`` sub-directories are used.  When ``None``
            (the default) it is located on the import path at call time.
        trt_inc_dir: Extra include directory (``-I``), presumably the one
            holding the TensorRT headers.
        trt_lib_dir: Extra library directory (``-L``), presumably the one
            holding the TensorRT libraries.

    Returns:
        The exit status of the ``ninja`` process (0 means success).

    Raises:
        ImportError: If ``torch_dir`` is not given and ``torch`` cannot be
            found.
    """
    if torch_dir is None:
        # Resolve lazily: a default evaluated at definition time would make
        # merely importing this module fail when torch is missing, even for
        # callers that pass torch_dir themselves.  importlib is used because
        # the ``imp`` module is deprecated (and removed in Python 3.12).
        import importlib.util

        spec = importlib.util.find_spec('torch')
        locations = spec.submodule_search_locations if spec else None
        if not locations:
            raise ImportError(
                "cannot locate the 'torch' package; "
                'pass torch_dir explicitly')
        torch_dir = list(locations)[0]

    # Shared rule definitions come first, then one group of build statements
    # per plugin.
    sections = [
        NINJA_TEMPLATE.substitute(
            torch_dir=torch_dir,
            cuda_dir=cuda_dir,
            trt_inc_dir=trt_inc_dir,
            trt_lib_dir=trt_lib_dir,
        )
    ]
    object_files = []
    for plugin in PLUGINS:
        sections.append(
            PLUGIN_TEMPLATE.substitute(
                plugin=plugin,
                plugin_dir=os.path.join(BASE_FOLDER, plugin),
            ))
        object_files += [plugin + '.pb.o', plugin + '.o']

    # Final statement links every object file into the shared library.
    sections.append(
        'build torch2trt_dynamic/libtorch2trt_dynamic.so: link '
        + ' '.join(object_files) + '\n')

    with open('build.ninja', 'w') as f:
        f.write(''.join(sections))

    # Hand ninja's exit status back instead of discarding it, so a failed
    # build can be detected by the caller.
    return subprocess.call(['ninja'])


if __name__ == '__main__':
    # Use build()'s return value as the process exit status so shell
    # pipelines and CI notice a failed build.
    raise SystemExit(build())