# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Deploy Pretrained Vision Model from MxNet on VTA
================================================
**Author**: `Thierry Moreau <https://homes.cs.washington.edu/~moreau/>`_

This tutorial provides an end-to-end demo of how to run ImageNet classification
inference on the VTA accelerator design. It showcases Relay as a front-end
compiler that can perform quantization (VTA only supports int8/32 inference)
as well as graph packing (to enable tensorization in the core) to massage
the compute graph for the hardware target.
"""
######################################################################
# Install dependencies
# --------------------
# To use the autotvm package in tvm, we need to install some extra dependencies.
# (change "3" to "2" if you use python2):
#
# .. code-block:: bash
#
# pip3 install --user mxnet requests "Pillow<7"
#
# Now return to the python code. Import packages.
from __future__ import absolute_import, print_function
import argparse, json, os, requests, sys, time
from io import BytesIO
from os.path import join, isfile
from PIL import Image
from mxnet.gluon.model_zoo import vision
import numpy as np
from matplotlib import pyplot as plt
import tvm
from tvm import te
from tvm import rpc, autotvm, relay
from tvm.contrib import graph_executor, utils, download
from tvm.contrib.debugger import debug_executor
from tvm.relay import transform
import vta
from vta.testing import simulator
from vta.top import graph_pack
# Make sure that TVM was compiled with RPC=1
assert tvm.runtime.enabled("rpc")
######################################################################
# Define the platform and model targets
# -------------------------------------
# Execute on CPU vs. VTA, and define the model.
# Load VTA parameters from the 3rdparty/vta-hw/config/vta_config.json file
env = vta.get_env()
# Set ``device=arm_cpu`` to run inference on the CPU
# or ``device=vta`` to run inference on the FPGA.
device = "vta"
target = env.target if device == "vta" else env.target_vta_cpu
# Dictionary lookup for when to start/end bit packing
pack_dict = {
    "resnet18_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"],
    "resnet34_v1": ["nn.max_pool2d", "nn.global_avg_pool2d"],
    "resnet18_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
    "resnet34_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
    "resnet50_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
    "resnet101_v2": ["nn.max_pool2d", "nn.global_avg_pool2d"],
}
# Name of Gluon model to compile
# The ``start_pack`` and ``stop_pack`` labels indicate where
# to start and end the graph packing relay pass: in other words
# where to start and finish offloading to VTA.
model = "resnet18_v1"
assert model in pack_dict
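# For example, to try another Gluon classifier you would first register its
# packing boundaries (a hypothetical entry; the op names must match the first
# and last VTA-offloadable operators in that model's Relay graph, which you
# can inspect by printing ``mod["main"]`` after the front end conversion below):
#
#     pack_dict["resnet50_v1"] = ["nn.max_pool2d", "nn.global_avg_pool2d"]
#     model = "resnet50_v1"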
######################################################################
# Obtain an execution remote
# --------------------------
# When target is 'pynq', reconfigure FPGA and runtime.
# Otherwise, if target is 'sim', execute locally.
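# For example, to program a Pynq board directly over RPC you might export
# the following before running this script (illustrative values; substitute
# your board's address):
#
# .. code-block:: bash
#
#     export VTA_RPC_HOST=192.168.2.99
#     export VTA_RPC_PORT=9091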
if env.TARGET not in ["sim", "tsim", "intelfocl"]:
    # Get remote from tracker node if environment variable is set.
    # To set up the tracker, you'll need to follow the "Auto-tuning
    # a convolutional network for VTA" tutorial.
    tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
    tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
    # Otherwise if you have a device you want to program directly from
    # the host, make sure you've set the variables below to the IP of
    # your board.
    device_host = os.environ.get("VTA_RPC_HOST", "192.168.2.99")
    device_port = os.environ.get("VTA_RPC_PORT", "9091")
    if not tracker_host or not tracker_port:
        remote = rpc.connect(device_host, int(device_port))
    else:
        remote = autotvm.measure.request_remote(
            env.TARGET, tracker_host, int(tracker_port), timeout=10000
        )

    # Reconfigure the JIT runtime and FPGA.
    # You can program the FPGA with your own custom bitstream
    # by passing the path to the bitstream file instead of None.
    reconfig_start = time.time()
    vta.reconfig_runtime(remote)
    vta.program_fpga(remote, bitstream=None)
    reconfig_time = time.time() - reconfig_start
    print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))

# In simulation mode, host the RPC server locally.
else:
    remote = rpc.LocalSession()
    if env.TARGET in ["intelfocl"]:
        # program intelfocl aocx
        vta.program_fpga(remote, bitstream="vta.bitstream")

# Get execution context from remote
ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0)
######################################################################
# Build the inference graph executor
# ----------------------------------
# Grab vision model from Gluon model zoo and compile with Relay.
# The compilation steps are:
#
# 1. Front end translation from MxNet into Relay module.
# 2. Apply 8-bit quantization: here we skip the first conv layer and the
#    dense layer, which will both be executed in fp32 on the CPU.
# 3. Perform graph packing to alter the data layout for tensorization.
# 4. Perform constant folding to reduce number of operators (e.g. eliminate batch norm multiply).
# 5. Perform relay build to object file.
# 6. Load the object file onto remote (FPGA device).
# 7. Generate graph executor, `m`.
#
# Load pre-configured AutoTVM schedules
with autotvm.tophub.context(target):
    # Populate the shape and data type dictionary for ImageNet classifier input
    dtype_dict = {"data": "float32"}
    shape_dict = {"data": (env.BATCH, 3, 224, 224)}

    # Get off the shelf gluon model, and convert to relay
    gluon_model = vision.get_model(model, pretrained=True)

    # Measure build start time
    build_start = time.time()

    # Start front end compilation
    mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict)

    # Update shape and type dictionary
    shape_dict.update({k: v.shape for k, v in params.items()})
    dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
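    # At this point you can inspect the imported Relay graph, for instance to
    # find suitable ``start_pack``/``stop_pack`` boundaries for a new model:
    #
    #     print(mod["main"])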
    if target.device_name == "vta":
        # Perform quantization in Relay
        # Note: We set opt_level to 3 in order to fold batch norm
        with tvm.transform.PassContext(opt_level=3):
            with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]):
                mod = relay.quantize.quantize(mod, params=params)
            # Perform graph packing and constant folding for VTA target
            assert env.BLOCK_IN == env.BLOCK_OUT
            # do device annotation if target is intelfocl or sim
            relay_prog = graph_pack(
                mod["main"],
                env.BATCH,
                env.BLOCK_OUT,
                env.WGT_WIDTH,
                start_name=pack_dict[model][0],
                stop_name=pack_dict[model][1],
                device_annot=(env.TARGET == "intelfocl" or env.TARGET == "sim"),
            )
    else:
        relay_prog = mod["main"]
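    # (Illustrative sketch) Other ``qconfig`` knobs can be explored here; for
    # instance, a calibration-based scale could replace the fixed global scale
    # (assumed settings, not used in this tutorial; ``calib_samples`` would be
    # a calibration dataset you provide):
    #
    #     with relay.quantize.qconfig(calibrate_mode="percentile", skip_conv_layers=[0]):
    #         mod = relay.quantize.quantize(mod, params=params, dataset=calib_samples)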
    # Compile Relay program with AlterOpLayout disabled
    if target.device_name != "vta":
        with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}):
            graph, lib, params = relay.build(
                relay_prog, target=target, params=params, target_host=env.target_host
            )
    else:
        if env.TARGET == "intelfocl" or env.TARGET == "sim":
            # multiple targets to run both on cpu and vta
            target = {"cpu": env.target_vta_cpu, "ext_dev": target}
        with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
            graph, lib, params = relay.build(
                relay_prog, target=target, params=params, target_host=env.target_host
            )

    # Measure Relay build time
    build_time = time.time() - build_start
    print(model + " inference graph built in {0:.2f}s!".format(build_time))
    # Send the inference library over to the remote RPC server
    temp = utils.tempdir()
    lib.export_library(temp.relpath("graphlib.tar"))
    remote.upload(temp.relpath("graphlib.tar"))
    lib = remote.load_module("graphlib.tar")

    if env.TARGET == "intelfocl" or env.TARGET == "sim":
        ctxes = [remote.ext_dev(0), remote.cpu(0)]
        m = graph_executor.create(graph, lib, ctxes)
    else:
        # Graph executor
        m = graph_executor.create(graph, lib, ctx)
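    # (Illustrative sketch) For per-operator timing, the ``debug_executor``
    # imported above could be used instead of the plain graph executor
    # (single-context case assumed); running it reports op-level statistics:
    #
    #     m = debug_executor.create(graph, lib, ctx)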
######################################################################
# Perform image classification inference
# --------------------------------------
# We run classification on an image sample from ImageNet.
# We just need to download the categories file, `synset.txt`,
# and an input test image.
# Download ImageNet categories
categ_url = "https://github.com/uwsampl/web-data/raw/main/vta/models/"
categ_fn = "synset.txt"
download.download(join(categ_url, categ_fn), categ_fn)
synset = eval(open(categ_fn).read())
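# Note: ``synset.txt`` stores a Python dict literal mapping the class index
# to a human-readable label, which is why ``eval`` recovers it directly.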
# Download test image
image_url = "https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg"
image_fn = "cat.png"
download.download(image_url, image_fn)
# Prepare test image for inference
image = Image.open(image_fn).resize((224, 224))
plt.imshow(image)
plt.show()
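# Normalize with per-channel ImageNet mean and standard deviation (on the
# 0-255 pixel scale), transpose to NCHW layout, and tile to the VTA batch size.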
image = np.array(image) - np.array([123.0, 117.0, 104.0])
image /= np.array([58.395, 57.12, 57.375])
image = image.transpose((2, 0, 1))
image = image[np.newaxis, :]
image = np.repeat(image, env.BATCH, axis=0)
# Set the network parameters and inputs
m.set_input(**params)
m.set_input("data", image)
# Perform inference and gather execution statistics
# More on :py:meth:`tvm.runtime.Module.time_evaluator`
num = 4 # number of times we run module for a single measurement
rep = 3 # number of measurements (we derive std dev from this)
timer = m.module.time_evaluator("run", ctx, number=num, repeat=rep)
if env.TARGET in ["sim", "tsim"]:
    simulator.clear_stats()
    timer()
    sim_stats = simulator.stats()
    print("\nExecution statistics:")
    for k, v in sim_stats.items():
        # Since we execute the workload many times, we need to normalize stats
        # Note that there is always one warm up run
        # Therefore we divide the overall stats by (num * rep + 1)
        print("\t{:<16}: {:>16}".format(k, v // (num * rep + 1)))
else:
    tcost = timer()
    std = np.std(tcost.results) * 1000
    mean = tcost.mean * 1000
    print("\nPerformed inference in %.2fms (std = %.2f) for %d samples" % (mean, std, env.BATCH))
    print("Average per sample inference time: %.2fms" % (mean / env.BATCH))
# Get classification results
tvm_output = m.get_output(0, tvm.nd.empty((env.BATCH, 1000), "float32", remote.cpu(0)))
for b in range(env.BATCH):
    top_categories = np.argsort(tvm_output.numpy()[b])
    # Report top-5 classification results
    print("\n{} prediction for sample {}".format(model, b))
    print("\t#1:", synset[top_categories[-1]])
    print("\t#2:", synset[top_categories[-2]])
    print("\t#3:", synset[top_categories[-3]])
    print("\t#4:", synset[top_categories[-4]])
    print("\t#5:", synset[top_categories[-5]])
    # This just checks that one of the 5 top categories
    # is one variety of cat; this is by no means an accurate
    # assessment of how quantization affects classification
    # accuracy, but is meant to catch changes to the
    # quantization pass that would degrade accuracy in the CI.
    cat_detected = False
    for k in top_categories[-5:]:
        if "cat" in synset[k]:
            cat_detected = True
    assert cat_detected