[Add] Yolov5 trt HRNet inference on video (#99)
* [Add] Yolov5 trt HRNet inference on video

* Fixed a bug with int overflow on tracking

* Revisions before review

Co-authored-by: Giannis Pastaltzidis <gpastal@iti.gr>
gpastal24 and gpastal authored Nov 27, 2022
1 parent 2de1176 commit 276290f
Showing 17 changed files with 163 additions and 78 deletions.
163 changes: 99 additions & 64 deletions SimpleHRNet.py
@@ -3,8 +3,8 @@
import torch
from torchvision.transforms import transforms

from models.hrnet import HRNet
from models.poseresnet import PoseResNet
from models_.hrnet import HRNet
from models_.poseresnet import PoseResNet
# from models.detectors.YOLOv3 import YOLOv3 # import only when multi-person is enabled


@@ -28,10 +28,11 @@ def __init__(self,
return_heatmaps=False,
return_bounding_boxes=False,
max_batch_size=32,
yolo_model_def="./models/detectors/yolo/config/yolov3.cfg",
yolo_model_def="yolov5n",
yolo_class_path="./models/detectors/yolo/data/coco.names",
yolo_weights_path="./models/detectors/yolo/weights/yolov3.weights",
device=torch.device("cpu")):
device=torch.device("cpu"),
enable_tensorrt=False):
"""
Initializes a new SimpleHRNet object.
HRNet (and YOLOv3) are initialized on the torch.device("device") and
@@ -83,43 +84,50 @@ def __init__(self,
self.yolo_class_path = yolo_class_path
self.yolo_weights_path = yolo_weights_path
self.device = device
self.enable_tensorrt = enable_tensorrt

if self.multiperson:
from models.detectors.YOLOv3 import YOLOv3
# if self.multiperson:
# from models.detectors.YOLOv3 import YOLOv3

if model_name in ('HRNet', 'hrnet'):
self.model = HRNet(c=c, nof_joints=nof_joints)
elif model_name in ('PoseResNet', 'poseresnet', 'ResNet', 'resnet'):
self.model = PoseResNet(resnet_size=c, nof_joints=nof_joints)
else:
raise ValueError('Wrong model name.')
if not self.enable_tensorrt:
checkpoint = torch.load(checkpoint_path, map_location=self.device)
if 'model' in checkpoint:
self.model.load_state_dict(checkpoint['model'])
else:
self.model.load_state_dict(checkpoint)

checkpoint = torch.load(checkpoint_path, map_location=self.device)
if 'model' in checkpoint:
self.model.load_state_dict(checkpoint['model'])
else:
self.model.load_state_dict(checkpoint)
if 'cuda' in str(self.device):
print("device: 'cuda' - ", end="")

if 'cuda' in str(self.device):
print("device: 'cuda' - ", end="")
if 'cuda' == str(self.device):
# if device is set to 'cuda', all available GPUs will be used
print("%d GPU(s) will be used" % torch.cuda.device_count())
device_ids = None
else:
# if device is set to 'cuda:IDS', only that/those device(s) will be used
print("GPU(s) '%s' will be used" % str(self.device))
device_ids = [int(x) for x in str(self.device)[5:].split(',')]

if 'cuda' == str(self.device):
# if device is set to 'cuda', all available GPUs will be used
print("%d GPU(s) will be used" % torch.cuda.device_count())
device_ids = None
self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)
elif 'cpu' == str(self.device):
print("device: 'cpu'")
else:
# if device is set to 'cuda:IDS', only that/those device(s) will be used
print("GPU(s) '%s' will be used" % str(self.device))
device_ids = [int(x) for x in str(self.device)[5:].split(',')]
raise ValueError('Wrong device name.')

self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)
elif 'cpu' == str(self.device):
print("device: 'cpu'")
self.model = self.model.to(device)
self.model.eval()
else:
raise ValueError('Wrong device name.')

self.model = self.model.to(device)
self.model.eval()
from torch2trt import torch2trt,TRTModule
## Load the TRT module.
self.model = TRTModule()
self.model.load_state_dict(torch.load(checkpoint_path))
self.model.cuda().eval()

if not self.multiperson:
self.transform = transforms.Compose([
@@ -128,12 +136,12 @@ def __init__(self,
])

else:
self.detector = YOLOv3(model_def=yolo_model_def,
class_path=yolo_class_path,
weights_path=yolo_weights_path,
classes=('person',),
max_batch_size=self.max_batch_size,
device=device)
temp_ = self.yolo_model_def.split('.')
if len(temp_)>1 and temp_[1]=='engine':
self.detector = torch.hub.load('ultralytics/yolov5','custom', self.yolo_model_def)

else:
self.detector = torch.hub.load('ultralytics/yolov5', self.yolo_model_def, pretrained=True)
self.transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((self.resolution[0], self.resolution[1])), # (height, width)
@@ -195,38 +203,53 @@ def _predict_single(self, image):
dtype=np.float32)

else:
detections = self.detector.predict_single(image)
detections = self.detector(image)
detections = detections.xyxy[0]
detections = detections[detections[:,4] >= 0.3] ## Should this check be removed?

detections = detections[detections[:,5] == 0.]
detections = detections.cpu().numpy()
nof_people = len(detections) if detections is not None else 0
boxes = np.empty((nof_people, 4), dtype=np.int32)
images = torch.empty((nof_people, 3, self.resolution[0], self.resolution[1])) # (height, width)
boxes = torch.empty((nof_people, 4),device=self.device)
images = torch.empty((nof_people, 3, self.resolution[0], self.resolution[1]),device = self.device) # (height, width)
heatmaps = np.zeros((nof_people, self.nof_joints, self.resolution[0] // 4, self.resolution[1] // 4),
dtype=np.float32)

if detections is not None:
for i, (x1, y1, x2, y2, conf, cls_conf, cls_pred) in enumerate(detections):
x1 = int(round(x1.item()))
x2 = int(round(x2.item()))
y1 = int(round(y1.item()))
y2 = int(round(y2.item()))

for i, (x1, y1, x2, y2, conf, cls_pred) in enumerate(detections):
# print(x1,x2,y1,y2)

x1 = int(round(x1))#int(round(x1.item()))
x2 = int(round(x2))
y1 = int(round(y1))
y2 = int(round(y2))
# print(x1,x2,y1,y2)
# Adapt detections to match HRNet input aspect ratio (as suggested by xtyDoge in issue #14)
correction_factor = self.resolution[0] / self.resolution[1] * (x2 - x1) / (y2 - y1)
# correction_factor = 256 / 192 * (x2 - x1) / (y2 - y1)
# Using padding instead of just bbox enlargement. This should reduce cross-person keypoint detection.
if correction_factor > 1:
# increase y side
center = y1 + (y2 - y1) // 2
length = int(round((y2 - y1) * correction_factor))
y1 = max(0, center - length // 2)
y2 = min(image.shape[0], center + length // 2)
y1_new = int( center - length // 2)
y2_new = int( center + length // 2)
image_crop = image[y1:y2, x1:x2, ::-1]
pad = (int(abs(y1_new-y1))), int(abs(y2_new-y2))
image_crop = np.pad(image_crop,((pad), (0, 0), (0, 0)))
images[i] = self.transform(image_crop)
boxes[i]= torch.tensor([x1, y1_new, x2, y2_new])

elif correction_factor < 1:
# increase x side
center = x1 + (x2 - x1) // 2
length = int(round((x2 - x1) * 1 / correction_factor))
x1 = max(0, center - length // 2)
x2 = min(image.shape[1], center + length // 2)

boxes[i] = [x1, y1, x2, y2]
images[i] = self.transform(image[y1:y2, x1:x2, ::-1])
x1_new = int( center - length // 2)
x2_new = int( center + length // 2)
image_crop = image[y1:y2, x1:x2, ::-1]
pad = (abs(x1_new-x1)), int(abs(x2_new-x2))
image_crop = np.pad(image_crop,((0, 0), (pad), (0, 0)))
images[i] = self.transform(image_crop)
boxes[i]= torch.tensor([x1_new, y1, x2_new, y2])

if images.shape[0] > 0:
images = images.to(self.device)
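An aside on the padding change in the hunk above: instead of enlarging the detection box to match HRNet's input aspect ratio (which can pull a neighbouring person into the crop), the new code keeps the original crop and zero-pads it to the corrected size. A minimal sketch with made-up numbers (not from the repo):

import numpy as np

image = np.zeros((480, 640, 3), dtype=np.uint8)   # stand-in frame
x1, y1, x2, y2 = 100, 120, 260, 300               # 160-wide, 180-tall person box
resolution = (256, 192)                           # HRNet input (height, width)

correction_factor = resolution[0] / resolution[1] * (x2 - x1) / (y2 - y1)  # ~1.19
crop = image[y1:y2, x1:x2]
if correction_factor > 1:
    # Box too wide for 256/192: extend the y side virtually, pad the crop with zeros.
    center = y1 + (y2 - y1) // 2
    length = int(round((y2 - y1) * correction_factor))
    y1_new, y2_new = int(center - length // 2), int(center + length // 2)
    pad = (abs(y1_new - y1), abs(y2_new - y2))
    crop = np.pad(crop, (pad, (0, 0), (0, 0)))
print(crop.shape)  # (212, 160, 3): height padded, no extra pixels pulled from the frame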
@@ -243,19 +266,29 @@ def _predict_single(self, image):
for i in range(0, len(images), self.max_batch_size):
out[i:i + self.max_batch_size] = self.model(images[i:i + self.max_batch_size])

out = out.detach().cpu().numpy()
pts = np.empty((out.shape[0], out.shape[1], 3), dtype=np.float32)
# out = out.detach().cpu().numpy()
pts = torch.empty((out.shape[0], out.shape[1], 3), dtype=torch.float32,device=self.device)
# For each human, for each joint: y, x, confidence
for i, human in enumerate(out):
heatmaps[i] = human
for j, joint in enumerate(human):
pt = np.unravel_index(np.argmax(joint), (self.resolution[0] // 4, self.resolution[1] // 4))
# 0: pt_y / (height // 4) * (bb_y2 - bb_y1) + bb_y1
# 1: pt_x / (width // 4) * (bb_x2 - bb_x1) + bb_x1
# 2: confidences
pts[i, j, 0] = pt[0] * 1. / (self.resolution[0] // 4) * (boxes[i][3] - boxes[i][1]) + boxes[i][1]
pts[i, j, 1] = pt[1] * 1. / (self.resolution[1] // 4) * (boxes[i][2] - boxes[i][0]) + boxes[i][0]
pts[i, j, 2] = joint[pt]
# Re-written in torch to keep the argmax on the GPU; not benchmarked against the numpy loop
(b,indices)=torch.max(out,dim=2)
(b,indices)=torch.max(b,dim=2)

(c,indicesc)=torch.max(out,dim=3)
(c,indicesc)=torch.max(c,dim=2)
dims = (self.resolution[0]//4,self.resolution[1]//4)
dim1= torch.tensor(1. / dims[0],device=self.device)
dim2= torch.tensor(1. / dims[1],device=self.device)

for i in range(0,out.shape[0]):

# 0: pt_y / (height // 4) * (bb_y2 - bb_y1) + bb_y1
# 1: pt_x / (width // 4) * (bb_x2 - bb_x1) + bb_x1
# 2: confidences

pts[i, :, 0] = indicesc[i,:] * dim1 * (boxes[i][3] - boxes[i][1]) + boxes[i][1]
pts[i, :, 1] = indices[i,:] *dim2* (boxes[i][2] - boxes[i][0]) + boxes[i][0]
pts[i, :, 2] = c[i,:]
pts=pts.cpu().numpy()

else:
pts = np.empty((0, 0, 3), dtype=np.float32)
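The torch rewrite in the hunk above replaces the per-person, per-joint np.unravel_index(np.argmax(...)) loop with two chained torch.max reductions that stay on the device. A quick equivalence check on random heatmaps (a sketch, not part of the commit):

import numpy as np
import torch

out = torch.rand(2, 17, 64, 48)   # (people, joints, height//4, width//4) fake heatmaps

b, _ = torch.max(out, dim=2)      # max over height -> (2, 17, 48)
_, xs = torch.max(b, dim=2)       # argmax over width = x of the global max
c, _ = torch.max(out, dim=3)      # max over width -> (2, 17, 64)
conf, ys = torch.max(c, dim=2)    # argmax over height = y, value = confidence

for i in range(out.shape[0]):
    for j in range(out.shape[1]):
        y, x = np.unravel_index(np.argmax(out[i, j].numpy()), (64, 48))
        assert ys[i, j].item() == y and xs[i, j].item() == x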
@@ -264,7 +297,7 @@ def _predict_single(self, image):
if self.return_heatmaps:
res.append(heatmaps)
if self.return_bounding_boxes:
res.append(boxes)
res.append(boxes.cpu().numpy())
res.append(pts)

if len(res) > 1:
@@ -301,8 +334,10 @@ def _predict_batch(self, images):
dtype=np.float32)

else:
image_detections = self.detector.predict(images)
image_detections = self.detector(images)
detections = image_detections.xyxy[0]

image_detections = detections[detections[:,5] == 0.]
base_index = 0
nof_people = int(np.sum([len(d) for d in image_detections if d is not None]))
boxes = np.empty((nof_people, 4), dtype=np.int32)
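With enable_tensorrt=True, the constructor now expects checkpoint_path to contain a torch2trt TRTModule state_dict rather than plain HRNet weights; the conversion itself is not part of this commit. A minimal sketch of producing such a checkpoint (weight paths, input size, and fp16 setting are assumptions):

import torch
from torch2trt import torch2trt
from models_.hrnet import HRNet

model = HRNet(c=48, nof_joints=17).cuda().eval()
ckpt = torch.load('./weights/pose_hrnet_w48_384x288.pth', map_location='cuda')  # hypothetical path
model.load_state_dict(ckpt['model'] if 'model' in ckpt else ckpt)

# Trace with a sample input matching the resolution SimpleHRNet will run at.
x = torch.randn(1, 3, 384, 288).cuda()
model_trt = torch2trt(model, [x], fp16_mode=True, max_batch_size=16)

# SimpleHRNet(..., checkpoint_path='./weights/hrnet_trt.pth', enable_tensorrt=True)
# can then load this state_dict via TRTModule.
torch.save(model_trt.state_dict(), './weights/hrnet_trt.pth')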
Binary file added misc/__pycache__/__init__.cpython-38.pyc
Binary file added misc/__pycache__/utils.cpython-38.pyc
Binary file added misc/__pycache__/visualization.cpython-38.pyc
4 changes: 4 additions & 0 deletions misc/utils.py
@@ -360,7 +360,11 @@ def oks_iou(g, d, a_g, a_d, sigmas=None, in_vis_thre=None):
if in_vis_thre is not None:
ind = list(vg > in_vis_thre) and list(vd > in_vis_thre)
e = e[ind]

e = e[e <= 2**32 - 1]

ious[n_d] = np.sum(np.exp(-e)) / e.shape[0] if e.shape[0] != 0 else 0.0

return ious


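A toy illustration (not from the repo) of the new guard in oks_iou: squared-distance entries that have blown up, e.g. via the tracking integer overflow mentioned in the commit message, are dropped before the exponential so they cannot poison the mean. Note the bound is written with ** (exponentiation); ^ would be bitwise XOR in Python.

import numpy as np

e = np.array([0.25, 4.0, 2.0 ** 40])   # last entry: an overflowed distance term
e = e[e <= 2**32 - 1]                  # keeps [0.25, 4.0]
iou = np.sum(np.exp(-e)) / e.shape[0] if e.shape[0] != 0 else 0.0
print(round(iou, 3))                   # 0.399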
1 change: 0 additions & 1 deletion models/detectors/yolo
Submodule yolo deleted from 47b7c9
File renamed without changes.
Binary file added models_/__pycache__/__init__.cpython-38.pyc
Binary file added models_/__pycache__/hrnet.cpython-38.pyc
Binary file added models_/__pycache__/modules.cpython-38.pyc
Binary file added models_/__pycache__/poseresnet.cpython-38.pyc
File renamed without changes.
2 changes: 1 addition & 1 deletion models/hrnet.py → models_/hrnet.py
@@ -1,6 +1,6 @@
import torch
from torch import nn
from models.modules import BasicBlock, Bottleneck
from models_.modules import BasicBlock, Bottleneck


class StageModule(nn.Module):
File renamed without changes.
2 changes: 1 addition & 1 deletion models/poseresnet.py → models_/poseresnet.py
@@ -1,6 +1,6 @@
import torch
from torch import nn
from models.modules import BasicBlock, Bottleneck
from models_.modules import BasicBlock, Bottleneck


resnet_spec = {
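A note on the renames above: torch.hub.load('ultralytics/yolov5', ...) places the YOLOv5 checkout, which ships its own models package, on sys.path, so the local package was presumably renamed from models to models_ to avoid that import collision (the YOLOv3 submodule it replaces was deleted outright).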
30 changes: 19 additions & 11 deletions scripts/live-demo.py
@@ -15,7 +15,7 @@

def main(camera_id, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, hrnet_joints_set, image_resolution,
single_person, use_tiny_yolo, disable_tracking, max_batch_size, disable_vidgear, save_video, video_format,
video_framerate, device):
video_framerate, device,enable_tensorrt):
if device is not None:
device = torch.device(device)
else:
@@ -44,13 +44,13 @@ def main(camera_id, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, hrnet_jo
video = CamGear(camera_id).start()

if use_tiny_yolo:
yolo_model_def="./models/detectors/yolo/config/yolov3-tiny.cfg"
yolo_class_path="./models/detectors/yolo/data/coco.names"
yolo_weights_path="./models/detectors/yolo/weights/yolov3-tiny.weights"
yolo_model_def="yolov5n.engine"
yolo_class_path=""
yolo_weights_path=""
else:
yolo_model_def="./models/detectors/yolo/config/yolov3.cfg"
yolo_class_path="./models/detectors/yolo/data/coco.names"
yolo_weights_path="./models/detectors/yolo/weights/yolov3.weights"
yolo_model_def="yolov5n"
yolo_class_path=""
yolo_weights_path=""

model = SimpleHRNet(
hrnet_c,
@@ -64,21 +64,24 @@ def main(camera_id, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, hrnet_jo
yolo_model_def=yolo_model_def,
yolo_class_path=yolo_class_path,
yolo_weights_path=yolo_weights_path,
device=device
device=device,
enable_tensorrt=enable_tensorrt
)

if not disable_tracking:
prev_boxes = None
prev_pts = None
prev_person_ids = None
next_person_id = 0

t_start = time.time()
while True:
t = time.time()

if filename is not None or disable_vidgear:
ret, frame = video.read()
if not ret:
t_end = time.time()
print("\n Total Time: " ,t_end-t_start)
break
if rotation_code is not None:
frame = cv2.rotate(frame, rotation_code)
@@ -98,6 +101,7 @@ def main(camera_id, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, hrnet_jo
person_ids = np.arange(next_person_id, len(pts) + next_person_id, dtype=np.int32)
next_person_id = len(pts) + 1
else:
# print(boxes)
boxes, pts, person_ids = find_person_id_associations(
boxes=boxes, pts=pts, prev_boxes=prev_boxes, prev_pts=prev_pts, prev_person_ids=prev_person_ids,
next_person_id=next_person_id, pose_alpha=0.2, similarity_threshold=0.4, smoothing_alpha=0.1,
@@ -117,10 +121,12 @@ def main(camera_id, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, hrnet_jo
frame = draw_points_and_skeleton(frame, pt, joints_dict()[hrnet_joints_set]['skeleton'], person_index=pid,
points_color_palette='gist_rainbow', skeleton_color_palette='jet',
points_palette_samples=10)
# for box in boxes:
# cv2.rectangle(frame,(box[0],box[1]),(box[2],box[3]),(255,255,255),2)

fps = 1. / (time.time() - t)
print('\rframerate: %f fps' % fps, end='')

print('\rframerate: %f fps, for %d person(s) ' % (fps,len(pts)), end='')
if has_display:
cv2.imshow('frame.png', frame)
k = cv2.waitKey(1)
@@ -181,5 +187,7 @@ def main(camera_id, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, hrnet_jo
"set to `cuda:IDS` to use one or more specific GPUs "
"(e.g. `cuda:0` `cuda:1,2`); "
"set to `cpu` to run on cpu.", type=str, default=None)
parser.add_argument("--enable_tensorrt", help="save output frames into a video.", action="store_true")

args = parser.parse_args()
main(**args.__dict__)
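For the --use_tiny_yolo path above, a prebuilt yolov5n.engine has to exist before the demo starts; SimpleHRNet then loads it through torch.hub's 'custom' entry point. A rough sketch of the detector round trip as the new code uses it (the export command is an assumption based on the ultralytics/yolov5 repo's export script):

import numpy as np
import torch

# One-time, hypothetical export step in a clone of ultralytics/yolov5:
#   python export.py --weights yolov5n.pt --include engine --device 0

detector = torch.hub.load('ultralytics/yolov5', 'custom', 'yolov5n.engine')
frame = np.zeros((480, 640, 3), dtype=np.uint8)   # stand-in camera frame

results = detector(frame)
det = results.xyxy[0]            # columns: x1, y1, x2, y2, confidence, class
det = det[det[:, 4] >= 0.3]      # same confidence cut as _predict_single
persons = det[det[:, 5] == 0.]   # class 0 is 'person' in COCO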