-
Notifications
You must be signed in to change notification settings - Fork 179
/
app.py
137 lines (118 loc) · 4.78 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import time
from absl import app, logging
import cv2
import numpy as np
import tensorflow as tf
from yolov3_tf2.models import (
YoloV3, YoloV3Tiny
)
from yolov3_tf2.dataset import transform_images, load_tfrecord_dataset
from yolov3_tf2.utils import draw_outputs
from flask import Flask, request, Response, jsonify, send_from_directory, abort
import os
# ---- API configuration: edit these values to serve a different model ----
# Model architecture and weights.
tiny = False  # True builds YoloV3Tiny instead of the full YoloV3 network
weights_path = './weights/yolov3.tf'
num_classes = 80  # class count the network is built with; should agree with the weights/labels
# Label file, input resolution and output location.
classes_path = './data/labels/coco.names'
size = 416  # size passed to transform_images when preparing an image for the model
output_path = './detections/'  # annotated images are saved here; used as a plain string prefix, so keep the trailing slash
# Load the model weights and class labels once, at import time, so every request reuses them.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if physical_devices:
    # Grow GPU memory on demand (first GPU only) instead of reserving it all up front.
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
yolo = YoloV3Tiny(classes=num_classes) if tiny else YoloV3(classes=num_classes)
# expect_partial() silences warnings about checkpoint values left unused by this restore.
yolo.load_weights(weights_path).expect_partial()
print('weights loaded')
# One label per line; the line index is the class id the model outputs.
# The context manager closes the labels file instead of leaking the handle.
with open(classes_path) as labels_file:
    class_names = [c.strip() for c in labels_file.readlines()]
print('classes loaded')
# Create the Flask WSGI application that the route decorators below register against.
app = Flask(import_name=__name__)
# API that returns JSON with classes found in images
@app.route('/detections', methods=['POST'])
def get_detections():
    """Run YOLOv3 on every uploaded image and return the detected classes as JSON.

    Expects a multipart POST whose ``images`` field holds one or more files.
    For each image (in upload order, numbered from 1), an annotated copy is
    written to ``output_path + 'detection<N>.jpg'``.

    Returns:
        A ``(response, 200)`` tuple whose JSON body is
        ``{"response": [{"image": <upload filename>,
                         "detections": [{"class": str, "confidence": float}, ...]}, ...]}``.
        ``confidence`` is the score multiplied by 100 and rounded to 2 decimals.

    Aborts with 400 if an upload cannot be decoded as an image.
    """
    uploads = request.files.getlist("images")
    image_names = []
    raw_images = []
    for image in uploads:
        image_names.append(image.filename)
        # Decode straight from the upload stream. The previous version saved each
        # upload to the working directory under the client-supplied filename (path
        # traversal), never closed the re-opened file, and failed on cleanup when
        # two uploads shared a filename.
        try:
            raw_images.append(tf.image.decode_image(image.read(), channels=3))
        except tf.errors.InvalidArgumentError:
            abort(400)

    # create list for final response
    response = []
    for num, (image_name, raw_img) in enumerate(zip(image_names, raw_images), start=1):
        # Add a batch dimension of 1; the [0] indexing below selects that single image.
        img = transform_images(tf.expand_dims(raw_img, 0), size)
        t1 = time.time()
        boxes, scores, classes, nums = yolo(img)
        t2 = time.time()
        print('time: {}'.format(t2 - t1))
        print('detections:')
        # list of detections for the current image
        responses = []
        for i in range(nums[0]):
            class_name = class_names[int(classes[0][i])]
            print('\t{}, {}, {}'.format(class_name,
                                        np.array(scores[0][i]),
                                        np.array(boxes[0][i])))
            responses.append({
                "class": class_name,
                "confidence": float("{0:.2f}".format(np.array(scores[0][i])*100))
            })
        response.append({
            "image": image_name,
            "detections": responses
        })
        # Save an annotated copy; cv2 expects BGR channel order.
        img = cv2.cvtColor(raw_img.numpy(), cv2.COLOR_RGB2BGR)
        img = draw_outputs(img, (boxes, scores, classes, nums), class_names)
        out_file = output_path + 'detection' + str(num) + '.jpg'
        cv2.imwrite(out_file, img)
        print('output saved to: {}'.format(out_file))
    return jsonify({"response": response}), 200
# API that returns image with detections on it
@app.route('/image', methods= ['POST'])
def get_image():
    """Run YOLOv3 on a single uploaded image and return it annotated, as PNG.

    Expects a multipart POST with one file in the ``images`` field. The
    annotated image is also written to ``output_path + 'detection.jpg'``;
    the fixed name means each request overwrites the previous file.

    Returns:
        A ``Response`` with status 200, mimetype ``image/png`` and the
        PNG-encoded annotated image as its body.

    Aborts with 400 if the upload cannot be decoded as an image.
    """
    image = request.files["images"]
    # Decode straight from the upload stream. The previous version saved the file
    # to the working directory under the client-supplied name (path traversal) and
    # leaked the re-opened file handle.
    try:
        img_raw = tf.image.decode_image(image.read(), channels=3)
    except tf.errors.InvalidArgumentError:
        abort(400)
    # Add a batch dimension of 1; the [0] indexing below selects that single image.
    img = transform_images(tf.expand_dims(img_raw, 0), size)
    t1 = time.time()
    boxes, scores, classes, nums = yolo(img)
    t2 = time.time()
    print('time: {}'.format(t2 - t1))
    print('detections:')
    for i in range(nums[0]):
        print('\t{}, {}, {}'.format(class_names[int(classes[0][i])],
                                    np.array(scores[0][i]),
                                    np.array(boxes[0][i])))
    # Draw detections; cv2 expects BGR channel order.
    img = cv2.cvtColor(img_raw.numpy(), cv2.COLOR_RGB2BGR)
    img = draw_outputs(img, (boxes, scores, classes, nums), class_names)
    out_file = output_path + 'detection.jpg'
    cv2.imwrite(out_file, img)
    print('output saved to: {}'.format(out_file))
    # prepare image for response; tobytes() replaces the deprecated tostring()
    _, img_encoded = cv2.imencode('.png', img)
    return Response(response=img_encoded.tobytes(), status=200, mimetype='image/png')
if __name__ == '__main__':
    # Start the development server on all interfaces, port 5000.
    # Flask's debug mode enables the Werkzeug interactive debugger, which lets any
    # client that can reach the server execute arbitrary Python code. Because this
    # binds to 0.0.0.0, debug mode is opt-in (set FLASK_DEBUG=1) rather than always on.
    app.run(debug=os.environ.get('FLASK_DEBUG') == '1', host='0.0.0.0', port=5000)