run.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import sys
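# make the local YOLOv5 checkout importable before loading the tracking modules that depend on it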
sys.path.insert(0, "modules/object_tracking/yolov5")
import argparse
import numpy as np
import cv2
import torch
from pathlib import Path
import json
from tqdm import tqdm
from collections import deque, defaultdict
from common.config_parser import get_config
from common.data_io import VideoDatasetLoader
from common.transforms import STTranTransform, YOLOv5Transform
from common.inference_utils import fill_sttran_entry_inference
from common.image_processing import convert_annotation_frame_to_video
from common.model_utils import (
bbox_pair_generation,
concat_separated_head,
construct_sliding_window,
generate_sliding_window_mask,
)
from common.plot import draw_bboxes
from common.metrics_utils import generate_triplets_scores
from modules.object_tracking import HeadTracking, ObjectTracking
from modules.gaze_following import GazeFollowing
from modules.gaze_following.head_association import assign_human_head_frame
from modules.sthoip_transformer.sttran_gaze import STTranGazeCrossAttention
from modules.object_tracking import FeatureExtractionResNet101
def parse_opt():
parser = argparse.ArgumentParser()
parser.add_argument("--source", type=str, default="", help="path to source")
parser.add_argument("--future", type=int, default=0, help="seconds in future")
parser.add_argument("--cfg", type=str, default="configs", help="path to configs")
parser.add_argument("--weights", type=str, default="weights", help="root folder for all pretrained weights")
parser.add_argument("--imgsz", type=int, default=640, help="train, val image size (pixels)")
parser.add_argument("--hoi-thres", type=float, default=0.25, help="threshold for HOI score")
parser.add_argument("--out", type=str, default="output", help="output folder")
parser.add_argument("--print", action="store_true", help="print HOIs")
opt = parser.parse_args()
return opt
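
# end-to-end inference: object/head tracking, gaze following, then sliding-window HOI prediction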
@torch.no_grad()
def main(opt):
source = Path(opt.source)
if not source.exists():
print(f"{source} does not exist, exit")
return -1
video_name = source.stem
future = opt.future
hoi_thres = opt.hoi_thres
print_hois = opt.print
cfg_path = Path(opt.cfg)
# path for all model weights
weight_path = Path(opt.weights)
yolo_weight_path = weight_path / "yolov5"
deepsort_weight_path = weight_path / "deep_sort"
gaze_following_weight_path = weight_path / "detecting_attended" / "model_videoatttarget.pt"
backbone_model_path = weight_path / "backbone"
sttran_word_vector_dir = weight_path / "semantic"
model_path = weight_path / "sttrangaze" / f"f{future}_final.pt"
# output files
out = opt.out
output_folder = Path(out) / video_name
    output_folder.mkdir(parents=True, exist_ok=True)
trace_file = output_folder / f"{video_name}_trace.json"
gaze_file = output_folder / f"{video_name}_gaze.json"
hoi_file = output_folder / f"{video_name}_hoi.txt"
result_file = output_folder / f"{video_name}_result.json"
result_video_file = output_folder / f"{video_name}_result.mp4"
# model params
imgsz = opt.imgsz
gaze = "cross"
global_token = True
device = "cuda:0" if torch.cuda.is_available() else "cpu"
cfg = get_config(str(cfg_path / "final" / f"eval_hyp_f{future}.yaml"))
sampling_mode = cfg["sampling_mode"]
    dim_gaze_heatmap = cfg["dim_gaze_heatmap"] # always 64x64, not used at test time
dim_transformer_ffn = cfg["dim_transformer_ffn"]
sttran_enc_layer_num = cfg["sttran_enc_layer_num"]
sttran_dec_layer_num = cfg["sttran_dec_layer_num"]
sttran_sliding_window = cfg["sttran_sliding_window"]
    separate_head = cfg["separate_head"] # always separate heads, not used at test time
    loss_type = cfg["loss_type"] # always focal loss, not used at test time
mlp_projection = cfg["mlp_projection"] # MLP in input embedding
sinusoidal_encoding = cfg["sinusoidal_encoding"] # sinusoidal positional encoding
# Object Tracking module
print(f"======================================")
object_tracking_module = ObjectTracking(
yolo_weights_path=str(yolo_weight_path / "vidor_yolov5l.pt"),
deep_sort_model_dir=str(deepsort_weight_path),
config_path=str(cfg_path / "object_tracking.yaml"),
device=device,
)
yolov5_stride = object_tracking_module.yolov5_stride
# Head Tracking and Gaze Following modules
print(f"======================================")
head_tracking_module = HeadTracking(
crowd_human_weight_path=str(yolo_weight_path / "crowdhuman_yolov5m.pt"),
deep_sort_model_dir=str(deepsort_weight_path),
config_path=str(cfg_path / "object_tracking.yaml"),
device=device,
)
print(f"======================================")
gaze_following_module = GazeFollowing(
weight_path=str(gaze_following_weight_path),
config_path=str(cfg_path / "gaze_following.yaml"),
device=device,
)
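    # IoU threshold and matching algorithm used to associate detected heads with tracked humans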
matching_iou_thres = 0.7
matching_method = "hungarian"
# Feature backbone
print(f"======================================")
feature_backbone = FeatureExtractionResNet101(backbone_model_path, download=True, finetune=False).to(device)
feature_backbone.requires_grad_(False)
feature_backbone.eval()
print(f"Feature backbone loaded from {backbone_model_path}")
# load available objects and interactions
with Path("vidhoi_related/obj_categories.json").open("r") as f:
object_classes = json.load(f)
with Path("vidhoi_related/pred_categories.json").open("r") as f:
interaction_classes = json.load(f)
with Path("vidhoi_related/pred_split_categories.json").open("r") as f:
temp_dict = json.load(f)
spatial_class_idxes = temp_dict["spatial"]
action_class_idxes = temp_dict["action"]
num_object_classes = len(object_classes)
num_interaction_classes = len(interaction_classes)
num_spatial_classes = len(spatial_class_idxes)
num_action_classes = len(action_class_idxes)
num_interaction_classes_loss = num_interaction_classes
# Transformer setup
print(f"Transformer configs: {cfg}")
loss_type_dict = {"spatial_head": "bce", "action_head": "bce"}
separate_head_num = [num_spatial_classes, -1]
separate_head_name = ["spatial_head", "action_head"]
class_idxes_dict = {"spatial_head": spatial_class_idxes, "action_head": action_class_idxes}
loss_gt_dict = {"spatial_head": "spatial_gt", "action_head": "action_gt"}
sttran_gaze_model = STTranGazeCrossAttention(
num_interaction_classes=num_interaction_classes_loss,
obj_class_names=object_classes,
spatial_layer_num=sttran_enc_layer_num,
cross_layer_num=1,
temporal_layer_num=sttran_dec_layer_num - 1,
dim_transformer_ffn=dim_transformer_ffn,
d_gaze=512,
cross_sa=True,
cross_ffn=False,
global_token=global_token,
mlp_projection=mlp_projection,
sinusoidal_encoding=sinusoidal_encoding,
dropout=0,
word_vector_dir=sttran_word_vector_dir,
sliding_window=sttran_sliding_window,
separate_head=separate_head_num,
separate_head_name=separate_head_name,
)
# load model weights
sttran_gaze_model = sttran_gaze_model.to(device)
    incompatibles = sttran_gaze_model.load_state_dict(torch.load(model_path, map_location=device))
    sttran_gaze_model.eval()
    print(f"STTranGaze loaded. Incompatible keys: {incompatibles}")
# Load the video
dataset = VideoDatasetLoader(
source, transform=YOLOv5Transform(imgsz, yolov5_stride), additional_transform=STTranTransform(img_size=imgsz)
)
frame_num = dataset.frame_num
print(f"Video {source} loaded with {frame_num} frames")
# warmup tracker
object_tracking_module.clear()
frame, frame0, _, _, _ = next(iter(dataset))
object_tracking_module.warmup(frame.to(device), frame0)
print(f"Object tracker warmup finished.")
# output video writer
fourcc = "mp4v"
    fps, w, h = round(dataset.fps), frame0.shape[1], frame0.shape[0] # decimal fps is not handled
video_writer = cv2.VideoWriter(str(result_video_file), cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
# results for one video
detection_dict = defaultdict(list)
gaze_list = []
hoi_list = []
result_list = []
# FIFO queues for sliding window
frames_queue = deque(maxlen=sttran_sliding_window)
frame_ids_queue = deque(maxlen=sttran_sliding_window)
    # iterate over the video to get object traces and human gazes
    # NOTE only detections and gazes at keyframes are written to files, but the output video contains all frames
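    # per-human hidden states of the recurrent gaze model, keyed by track id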
hx_memory = {}
print(f"============ Inference Start... ============")
    t = tqdm(enumerate(iter(dataset)), total=frame_num)
for idx, batch in t:
frame, frame0, _, _, meta_info = batch
meta_info["original_shape"] = frame0.shape
# object tracking
bboxes, ids, labels, names, confs, _ = object_tracking_module.track_one(frame.to(device), frame0, draw=False)
# draw bbox
if len(ids) > 0:
frame_annotated = draw_bboxes(frame0.copy(), bboxes, ids, labels, names, confs)
else:
frame_annotated = frame0.copy()
# get human bboxes
human_idxes = np.array(labels) == 0
human_bboxes = np.array(bboxes)[human_idxes]
human_ids = np.array(ids)[human_idxes]
# head detection
h_bboxes, _, _, _, h_confs, _ = head_tracking_module.track_one(frame.to(device), frame0, draw=False)
# human-head association
head_bbox_dict = assign_human_head_frame(
h_bboxes, h_confs, human_bboxes, human_ids, matching_iou_thres, matching_method
)
# gaze following for each human
frame_gaze_dict = {}
for human_id, head_bbox in head_bbox_dict.items():
# no head found for this human_id
if len(head_bbox) == 0:
frame_gaze_dict[int(human_id)] = []
continue
# check hidden state memory
if human_id in hx_memory:
hidden_state = hx_memory[human_id]
else:
hidden_state = None
# gaze model forward
heatmap, _, hx, _, _, frame_annotated = gaze_following_module.detect_one(
frame0, head_bbox, hidden_state, draw=True, frame_to_draw=frame_annotated, id=human_id, arrow=True
)
# update hidden state memory
hx_memory[human_id] = (hx[0].detach(), hx[1].detach())
            # the 64x64 heatmap does not include the in/out score; in/out info is stored separately
frame_gaze_dict[int(human_id)] = heatmap.tolist()
# store result for every second
if idx % fps == 0:
bboxes = [bbox.tolist() for bbox in bboxes]
detection_dict["bboxes"].append(bboxes)
detection_dict["ids"].append(ids)
detection_dict["labels"].append(labels)
detection_dict["confidences"].append(confs)
detection_dict["frame_ids"].append(meta_info["frame_count"])
gaze_list.append(frame_gaze_dict)
# write video frame
video_writer.write(frame_annotated)
# predict HOIs every second
if idx % fps == 0:
# generate sliding window
frames_queue.append(meta_info["additional"])
frame_ids_queue.append(meta_info["frame_count"])
sttran_frames = torch.cat(list(frames_queue)).to(device)
det_bboxes = detection_dict["bboxes"][-sttran_sliding_window:]
det_ids = detection_dict["ids"][-sttran_sliding_window:]
det_labels = detection_dict["labels"][-sttran_sliding_window:]
det_confidences = detection_dict["confidences"][-sttran_sliding_window:]
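            # merge the per-frame detection lists of the current window into clip-level lists
            # (presumably attaching a window-frame index to each bbox)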
bboxes, ids, pred_labels, confidences = convert_annotation_frame_to_video(
det_bboxes, det_ids, det_labels, det_confidences
)
if len(bboxes) == 0:
# no detection
pair_idxes = []
im_idxes = []
else:
# Generate human-object pairs
pair_idxes, im_idxes = bbox_pair_generation(bboxes, pred_labels, 0)
detected = {
"bboxes": bboxes,
"pred_labels": pred_labels,
"ids": ids,
"confidences": confidences,
"pair_idxes": pair_idxes,
"im_idxes": im_idxes,
}
# fill the entry with detections
entry = fill_sttran_entry_inference(
sttran_frames,
detected,
gaze_list[-sttran_sliding_window:],
feature_backbone,
meta_info,
loss_type_dict,
class_idxes_dict,
loss_gt_dict,
device,
annotations=None,
human_label=0,
)
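            # build sliding windows over the buffered keyframes; invalid windows (e.g. without pairs) are masked out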
windows = construct_sliding_window(entry, sampling_mode, sttran_sliding_window, 0, None, gt=False)
entry, windows, windows_out, out_im_idxes, _ = generate_sliding_window_mask(entry, windows, None, "pair")
# only do model forward if any valid window exists
if len(windows) > 0:
# everything to GPU
entry["pair_idxes"] = entry["pair_idxes"].to(device)
for i in range(len(entry["full_heatmaps"])):
entry["full_heatmaps"][i] = entry["full_heatmaps"][i].to(device)
entry["pred_labels"] = entry["pred_labels"].to(device)
entry["windows"] = windows.to(device)
entry["windows_out"] = windows_out.to(device)
# forward
entry = sttran_gaze_model(entry)
# sigmoid or softmax
for head_name in loss_type_dict.keys():
if loss_type_dict[head_name] == "ce":
entry[head_name] = torch.softmax(entry[head_name], dim=-1)
else:
entry[head_name] = torch.sigmoid(entry[head_name])
                # at inference time the number of predictions may differ from the number of GT interactions
# len_preds = len(interactions_gt)
len_preds = len(entry[list(loss_type_dict.keys())[0]])
interaction_distribution = concat_separated_head(
entry, len_preds, loss_type_dict, class_idxes_dict, device, True
)
# process output
frame_ids = list(frame_ids_queue)
# window-wise result entry
out_im_idx = len(frame_ids) - 1
window_anno = {
"video_name": video_name, # video name
"frame_id": frame_ids[out_im_idx], # this frame id
}
if sampling_mode == "anticipation":
if idx + future >= frame_num:
window_anno["future_frame_id"] = ""
else:
window_anno["future_frame_id"] = f"{idx + fps * future:06d}"
            window_prediction = {
                "bboxes": [],
                "pred_labels": [],
                "ids": [], # filled below together with the other detection fields
                "confidences": [],
                "pair_idxes": [],
                "interaction_distribution": [],
            }
# case 1, nothing detected in the full clip, result all []
if len(entry["bboxes"]) == 0:
pass
else:
det_out_idxes = entry["bboxes"][:, 0] == out_im_idx
# case 2, nothing detected in this window, result all []
if not det_out_idxes.any():
pass
else:
# something detected, fill object detection results
# NOTE det_idx_offset is the first bbox index in this window
det_idx_offset = det_out_idxes.nonzero(as_tuple=True)[0][0]
window_prediction["bboxes"] = entry["bboxes"][det_out_idxes, 1:].numpy().tolist()
window_prediction["pred_labels"] = entry["pred_labels"][det_out_idxes].cpu().numpy().tolist()
window_prediction["confidences"] = entry["confidences"][det_out_idxes].numpy().tolist()
window_prediction["ids"] = torch.LongTensor(entry["ids"])[det_out_idxes].numpy().tolist()
pair_out_idxes = entry["im_idxes"] == out_im_idx
# case 3, no human-object pair detected (no human or no object), pair_idxes and distribution []
if not pair_out_idxes.any():
pass
else:
# case 4, have everything
pair_idxes = entry["pair_idxes"][pair_out_idxes] - det_idx_offset
# handle interaction distributions
window_prediction["pair_idxes"] = pair_idxes.cpu().numpy().tolist()
window_prediction["interaction_distribution"] = interaction_distribution.cpu().numpy().tolist()
window_result = {**window_anno, **window_prediction}
result_list.append(window_result)
# print HOIs, only considering interaction scores
triplets_scores = generate_triplets_scores(
window_result["pair_idxes"],
[1.0] * len(window_result["confidences"]),
window_result["interaction_distribution"],
multiply=True,
top_k=100,
thres=hoi_thres,
)
s_hois = "-------------------------------\n"
s_hois += f"Frame {idx}/{frame_num}:\n"
for score, idx_pair, interaction_pred in triplets_scores:
subj_idx = window_result["pair_idxes"][idx_pair][0]
subj_cls = window_result["pred_labels"][subj_idx]
subj_name = object_classes[subj_cls]
subj_id = window_result["ids"][subj_idx]
obj_idx = window_result["pair_idxes"][idx_pair][1]
obj_cls = window_result["pred_labels"][obj_idx]
obj_name = object_classes[obj_cls]
obj_id = window_result["ids"][obj_idx]
interaction_name = interaction_classes[interaction_pred]
s_hois += f"{subj_name}{subj_id} - {interaction_name} - {obj_name}{obj_id}: {score}\n"
hoi_list.append(s_hois)
if print_hois:
print(s_hois)
# release video writer
video_writer.release()
# store detections and gazes
with trace_file.open("w") as f:
json.dump(detection_dict, f)
with gaze_file.open("w") as f:
json.dump(gaze_list, f)
# store HOIs
with result_file.open("w") as f:
json.dump(result_list, f)
with hoi_file.open("w") as f:
f.writelines(hoi_list)
if __name__ == "__main__":
opt = parse_opt()
main(opt)
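
# Example invocation (path is illustrative; --cfg/--weights/--out fall back to their defaults):
#   python run.py --source /path/to/video.mp4 --future 1 --print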