# This script is used to classify objects in a video using the YOLOv8 model.
# The script reads a video from a message queue, classifies the objects in the video, and writes the annotated video to a message queue.
# It saves the detected objects in a json file and the annotated video locally.
# For this it uses the ultralytics package to perform object detection and tracking.
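#
# Per message, the script: 1) receives a message from RabbitMQ, 2) retrieves the referenced
# video from Kerberos Vault, 3) loads the YOLOv8 model on CPU or GPU, 4) opens the video,
# 5) tracks objects frame by frame (optionally extracting their dominant colors),
# 6) annotates the bbox overview frame, 7) builds the ReturnJSON result, and 8) releases the
# video resources. The result can be saved locally and forwarded to a target queue.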
# Local imports
from utils.ReturnObject import ReturnJSON
from utils.TranslateObject import translate
from utils.VariableClass import VariableClass
from utils.ColorDetector import FindObjectColors
from utils.ClassificationObject import ClassificationObject
from utils.AnnotateFrame import annotate_frame, annotate_bbox_frame
from utils.ClassificationObjectFunctions import create_classification_object, edit_classification_object, find_classification_object
# External imports
import os
import cv2
import time
import json
import torch
import numpy as np
from ultralytics import YOLO
from uugai_python_dynamic_queue.MessageBrokers import RabbitMQ
from uugai_python_kerberos_vault.KerberosVault import KerberosVault
# Following error is thrown: [W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware.
# https://stackoverflow.com/questions/69711410/could-not-initialize-nnpack
# torch.backends.nnpack.enabled = False
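# The warning is non-fatal; if it appears on unsupported hardware, uncommenting the line above
# disables NNPACK, as suggested in the linked Stack Overflow thread.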
# Initialize the VariableClass object, which contains all the necessary environment variables.
var = VariableClass()
# Initialize a message broker using the python_queue_reader package
if var.LOGGING:
    print('a) Initializing RabbitMQ')
rabbitmq = RabbitMQ(
    queue_name=var.QUEUE_NAME,
    target_queue_name=var.TARGET_QUEUE_NAME,
    exchange=var.QUEUE_EXCHANGE,
    host=var.QUEUE_HOST,
    username=var.QUEUE_USERNAME,
    password=var.QUEUE_PASSWORD)
# Initialize Kerberos Vault
if var.LOGGING:
    print('b) Initializing Kerberos Vault')
kerberos_vault = KerberosVault(
    storage_uri=var.STORAGE_URI,
    storage_access_key=var.STORAGE_ACCESS_KEY,
    storage_secret_key=var.STORAGE_SECRET_KEY)
while True:
    # Receive message from the queue, and retrieve the media from the Kerberos Vault utilizing the message information.
    if var.LOGGING:
        print('1) Receiving message from RabbitMQ')
    message = rabbitmq.receive_message()
    if message == []:
        if var.LOGGING:
            print('No message received, waiting for 3 seconds')
        time.sleep(3)
        continue

    if var.LOGGING:
        print('2) Retrieving media from Kerberos Vault')
    resp = kerberos_vault.retrieve_media(
        message=message,
        media_type='video',
        media_savepath=var.MEDIA_SAVEPATH)

    if var.TIME_VERBOSE:
        start_time = time.time()
        total_time_preprocessing = 0
        total_time_class_prediction = 0
        total_time_color_prediction = 0
        total_time_processing = 0
        total_time_postprocessing = 0
        start_time_preprocessing = time.time()

    # Perform object classification on the media.
    # Initialise the YOLO model, additionally use the device parameter to specify the device to run the model on.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    MODEL = YOLO(var.MODEL_NAME).to(device)
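    # Note: the model is re-initialised for every message, which also resets the tracker state
    # that persist=True accumulates, presumably so that track IDs start fresh for each video.
    # Hoisting this above the outer loop would avoid reloading the weights, at the cost of
    # sharing tracker state across videos.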
    if var.LOGGING:
        print(f'3) Using device: {device}')

    # Open video-capture/recording using the video-path. Throw FileNotFoundError if cap is unable to open.
    if var.LOGGING:
        print(f'4) Opening video file: {var.MEDIA_SAVEPATH}')
    cap = cv2.VideoCapture(var.MEDIA_SAVEPATH)
    if not cap.isOpened():
        raise FileNotFoundError('Unable to open video file')

    # Initialize the video-writer if SAVE_VIDEO is set to True.
    if var.SAVE_VIDEO:
        fourcc = cv2.VideoWriter.fourcc(*'avc1')
        video_out = cv2.VideoWriter(
            filename=var.OUTPUT_MEDIA_SAVEPATH,
            fourcc=fourcc,
            fps=var.CLASSIFICATION_FPS,
            frameSize=(int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                       int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        )
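        # Note: the 'avc1' (H.264) fourcc requires an OpenCV build with H.264 support;
        # if the writer fails to open, 'mp4v' is a commonly used fallback codec.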

    if var.FIND_DOMINANT_COLORS:
        color_detector = FindObjectColors(
            downsample_factor=0.7,
            min_clusters=var.MIN_CLUSTERS,
            max_clusters=var.MAX_CLUSTERS,
        )

    # Initialize the classification process.
    # 2 lists are initialized:
    #   - Classification objects
    #   - Additional list for easy access to the ids.
    classification_object_list: list[ClassificationObject] = []
    classification_object_ids: list[int] = []

    # frame_number -> The current frame number. Depending on the frame_skip_factor this can make jumps.
    # predicted_frames -> The number of frames that were used for the prediction. This goes up by one each prediction iteration.
    # frame_skip_factor is the factor by which the input video frames are skipped.
    frame_number, predicted_frames = 0, 0
    frame_skip_factor = int(cap.get(cv2.CAP_PROP_FPS) / var.CLASSIFICATION_FPS)
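    # Example: a 30 fps source with CLASSIFICATION_FPS = 5 gives frame_skip_factor = 6, so only
    # every 6th frame is classified. This assumes CLASSIFICATION_FPS does not exceed the source
    # fps; otherwise the factor becomes 0 and the modulo check below raises a ZeroDivisionError.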

    # Loop over the video frames, and perform object classification.
    # The classification process is done until the counter reaches the MAX_NUMBER_OF_PREDICTIONS or the last frame is reached.
    MAX_FRAME_NUMBER = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    if var.LOGGING:
        print('5) Classifying frames')
    if var.TIME_VERBOSE:
        total_time_preprocessing += time.time() - start_time_preprocessing
        start_time_processing = time.time()

    while (predicted_frames < var.MAX_NUMBER_OF_PREDICTIONS) and (frame_number < MAX_FRAME_NUMBER):

        # Read the frame from the video-capture.
        success, frame = cap.read()
        if not success:
            break

        # Keep the first frame in memory if CREATE_BBOX_FRAME is set to True.
        # This is used to draw the tracking results on.
        if var.CREATE_BBOX_FRAME and frame_number == 0:
            bbox_frame = frame.copy()

        # Check if the frame_number corresponds to a frame that should be classified.
        if frame_number % frame_skip_factor == 0:

            # Perform object classification on the frame.
            # persist=True -> The tracking results are stored in the model.
            # persist should be kept True, as this provides unique IDs for each detection.
            # More information about the tracking results via https://docs.ultralytics.com/reference/engine/results/
            if var.TIME_VERBOSE:
                start_time_class_prediction = time.time()
            results = MODEL.track(
                source=frame,
                persist=True,
                verbose=False,
                conf=var.CLASSIFICATION_THRESHOLD,
                classes=var.ALLOWED_CLASSIFICATIONS)
            if var.TIME_VERBOSE:
                total_time_class_prediction += time.time() - start_time_class_prediction

            # Check if the results are not None; otherwise, the postprocessing should not be done.
            # Iterate over the detected objects and their masks.
            if results is not None:

                # Loop over boxes and masks.
                # If no masks are found, meaning the model used is not a segmentation model, the mask is set to None.
                for box, mask in zip(results[0].boxes, results[0].masks or [None] * len(results[0].boxes)):

                    # Check if objects are detected.
                    # If no object is detected, box.id will be None.
                    # In this case, the inner loop is broken without calling the object-related functions.
                    if box.id is None:
                        break

                    # Extract the object's id, name, confidence, and trajectory.
                    # Also include the mask, if a segmentation model was used. Otherwise, the mask is set to None.
                    # The crop_and_detect function will use the trajectory instead if no mask is provided.
                    object_id = int(box.id)
                    object_name = translate(results[0].names[int(box.cls)])
                    object_conf = float(box.conf)
                    object_trajectory = box.xyxy.tolist()[0]
                    object_mask = np.int32(
                        mask.xy[0].tolist()) if mask is not None else None
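                    # box.xyxy holds the bounding box as [x1, y1, x2, y2] in pixel coordinates,
                    # and mask.xy holds the segmentation polygon as pixel-coordinate points
                    # (see the Ultralytics Results documentation linked above).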

                    # Check if the id is already in the classification_object_ids list.
                    # If it is, edit the classification object.
                    # Otherwise, create a new classification object.
                    if object_id in classification_object_ids:
                        classification_object = find_classification_object(
                            classification_object_list, object_id)

                        # Calculate the dominant colors of the object if the FIND_DOMINANT_COLORS parameter is set to True
                        # and the object has been detected a multiple of COLOR_PREDICTION_INTERVAL times.
                        if var.FIND_DOMINANT_COLORS and classification_object.occurences % var.COLOR_PREDICTION_INTERVAL == 0:
                            if var.TIME_VERBOSE:
                                start_time_color_prediction = time.time()
                            main_colors_bgr, main_colors_hls, main_colors_str = color_detector.crop_and_detect(
                                frame=frame,
                                trajectory=object_trajectory,
                                mask_polygon=object_mask)
                            if var.TIME_VERBOSE:
                                total_time_color_prediction += time.time() - start_time_color_prediction
                        else:
                            main_colors_bgr, main_colors_hls, main_colors_str = None, None, None

                        edit_classification_object(
                            id=object_id,
                            object_name=object_name,
                            object_conf=object_conf,
                            trajectory=object_trajectory,
                            frame_number=frame_number,
                            classification_object_list=classification_object_list,
                            colors_bgr=main_colors_bgr,
                            colors_hls=main_colors_hls,
                            colors_str=main_colors_str)

                    else:
                        # Calculate the dominant colors of the object if the FIND_DOMINANT_COLORS parameter is set to True.
                        if var.FIND_DOMINANT_COLORS:
                            if var.TIME_VERBOSE:
                                start_time_color_prediction = time.time()
                            main_colors_bgr, main_colors_hls, main_colors_str = color_detector.crop_and_detect(
                                frame=frame,
                                trajectory=object_trajectory,
                                mask_polygon=object_mask)
                            if var.TIME_VERBOSE:
                                total_time_color_prediction += time.time() - start_time_color_prediction
                        else:
                            main_colors_bgr, main_colors_hls, main_colors_str = None, None, None

                        classification_object = create_classification_object(
                            id=object_id,
                            first_object_name=object_name,
                            first_object_conf=object_conf,
                            first_trajectory=object_trajectory,
                            first_frame=frame_number,
                            frame_width=int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                            frame_height=int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
                            first_colors_bgr=main_colors_bgr,
                            first_colors_hls=main_colors_hls,
                            first_colors_str=main_colors_str)

                        classification_object_ids.append(object_id)
                        classification_object_list.append(classification_object)

            # Depending on the SAVE_VIDEO or PLOT parameter, the frame is annotated.
            # This is done using a custom annotation function.
            if var.SAVE_VIDEO or var.PLOT:
                annotated_frame = annotate_frame(
                    frame=frame,
                    frame_number=frame_number,
                    classification_object_list=classification_object_list,
                    min_distance=var.MIN_DISTANCE,
                    min_detections=var.MIN_DETECTIONS)

                # Show the annotated frame if the PLOT parameter is set to True.
                cv2.imshow("YOLOv8 Tracking", annotated_frame) if var.PLOT else None
                cv2.waitKey(1) if var.PLOT else None

                # Write the annotated frame to the video-writer if the SAVE_VIDEO parameter is set to True.
                video_out.write(annotated_frame) if var.SAVE_VIDEO else None

            # Increase the number of predicted frames by one; frame_number is increased for
            # every frame that is read, whether it was classified or not.
            predicted_frames += 1

        frame_number += 1

    if var.TIME_VERBOSE:
        total_time_processing += time.time() - start_time_processing
        start_time_postprocessing = time.time()

    # Depending on the CREATE_BBOX_FRAME parameter, the bbox_frame is annotated.
    # This is done using a custom annotation function.
    if var.CREATE_BBOX_FRAME:
        if var.LOGGING:
            print('6) Annotating bbox frame')
        bbox_frame = annotate_bbox_frame(
            bbox_frame=bbox_frame,
            classification_object_list=classification_object_list)

    # Depending on the CREATE_RETURN_JSON parameter, the detected objects are saved in a json file.
    # Initialize the ReturnJSON object.
    # This creates a json object with the correct structure.
    if var.CREATE_RETURN_JSON:
        if var.LOGGING:
            print('7) Creating ReturnJSON object')
        return_json = ReturnJSON()

        # Depending on the user preference, the detected objects are filtered.
        # In this case, the objects are filtered based on the MIN_DETECTIONS parameter.
        filtered_classification_object_list = []
        for classification_object in classification_object_list:
            if classification_object.occurences >= var.MIN_DETECTIONS:
                filtered_classification_object_list.append(
                    classification_object)
                return_json.add_detected_object(classification_object)

        if var.LOGGING:
            print(f"\t - {len(classification_object_list)} objects were detected, of which {len(filtered_classification_object_list)} objects were detected at least {var.MIN_DETECTIONS} times.")

        # Depending on the SAVE_RETURN_JSON parameter, the return_json object is saved locally.
        return_json.save_returnjson(
            var.RETURN_JSON_SAVEPATH) if var.SAVE_RETURN_JSON else None

    # Depending on the SAVE_BBOX_FRAME parameter, the bbox_frame is saved locally.
    cv2.imwrite(var.BBOX_FRAME_SAVEPATH,
                bbox_frame) if var.SAVE_BBOX_FRAME else None

    if var.TIME_VERBOSE:
        total_time_postprocessing += time.time() - start_time_postprocessing

    # Depending on the TARGET_QUEUE_NAME parameter, the resulting JSON-object is sent to the target queue.
    # This is done by adding the data to the original message.
    if var.TARGET_QUEUE_NAME != "":
        message['operation'] = return_json.return_object['operation']
        message['data'] = return_json.return_object['data']
        return_message = json.dumps(message)
        rabbitmq.send_message(return_message)
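        # Note: this branch assumes CREATE_RETURN_JSON was enabled, since return_json is only
        # created in that block. The outgoing payload is the original queue message with the
        # 'operation' and 'data' fields of the ReturnJSON object merged in, serialised as JSON.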

    # Depending on the TIME_VERBOSE parameter, the time it took to classify the objects is printed.
    if var.TIME_VERBOSE:
        print(
            f'\t - Classification took: {round(time.time() - start_time, 1)} seconds, @ {var.CLASSIFICATION_FPS} fps.')
        print(
            f'\t\t - {round(total_time_preprocessing, 2)}s for preprocessing and initialisation')
        print(
            f'\t\t - {round(total_time_processing, 2)}s for processing of which:')
        print(
            f'\t\t\t - {round(total_time_class_prediction, 2)}s for class prediction')
        print(
            f'\t\t\t - {round(total_time_color_prediction, 2)}s for color prediction')
        print(
            f'\t\t\t - {round(total_time_processing - total_time_class_prediction - total_time_color_prediction, 2)}s for other processing')
        print(
            f'\t\t - {round(total_time_postprocessing, 2)}s for postprocessing')
        print(f'\t - Original video: {round(cap.get(cv2.CAP_PROP_FRAME_COUNT)/cap.get(cv2.CAP_PROP_FPS), 1)} seconds, @ {round(cap.get(cv2.CAP_PROP_FPS), 1)} fps @ {int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))}x{int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))}. File size of {round(os.path.getsize(var.MEDIA_SAVEPATH)/1024**2, 1)} MB')

    # If the video-writer was active, it is released.
    # Close the video-capture and destroy all windows.
    if var.LOGGING:
        print('8) Releasing video writer and closing video capture')
        print("\n\n")
    video_out.release() if var.SAVE_VIDEO else None
    cap.release()
    cv2.destroyAllWindows()