From 6bfcdaf5bcb006b945af9735883b198a54f62d4c Mon Sep 17 00:00:00 2001 From: Stefano Pini Date: Thu, 29 Dec 2022 18:17:05 +0000 Subject: [PATCH] Add TensorRT Support and YOLOv5 (#100) * [Add] Yolov5 trt HRNet inference on video (#99) * [Add] Yolov5 trt HRNet inference on video * Fixed a bug with int overflow on tracking * Revisions before review Co-authored-by: Giannis Pastaltzidis * Restore submodule * Rename models folder and remove pycache * Add .gitignore * Fix models imports * Refactor code, Update docstring * Fix YOLOv5 sizes and loading * Rename models to fix issue with YOLOv5 from torch hub * Fix issues with yolov5 and tensorrt, Update extract-keypoints script * Fix bug in yolov5.engine detections Fix bug in detection locations when tensorrt-converted yolov5 is used * Add Google Colab notebook * Update README.md * Update notebook Co-authored-by: gpastal24 <104021640+gpastal24@users.noreply.github.com> Co-authored-by: Giannis Pastaltzidis --- .gitignore | 160 +++++++ .gitmodules | 2 +- README.md | 35 +- SimpleHRNet.py | 174 +++++--- SimpleHRNet_notebook.ipynb | 537 ++++++++++++++++++++++++ datasets/LiveCamera.py | 8 +- misc/utils.py | 4 + {models => models_}/__init__.py | 0 {models => models_}/detectors/YOLOv3.py | 6 +- models_/detectors/YOLOv5.py | 103 +++++ {models => models_}/detectors/yolo | 0 {models => models_}/hrnet.py | 2 +- {models => models_}/modules.py | 0 {models => models_}/poseresnet.py | 2 +- scripts/export-tensorrt-model.py | 53 +++ scripts/extract-keypoints.py | 55 ++- scripts/live-demo.py | 61 ++- testing/Test.py | 2 +- training/Train.py | 2 +- 19 files changed, 1110 insertions(+), 96 deletions(-) create mode 100644 .gitignore create mode 100644 SimpleHRNet_notebook.ipynb rename {models => models_}/__init__.py (100%) rename {models => models_}/detectors/YOLOv3.py (97%) create mode 100644 models_/detectors/YOLOv5.py rename {models => models_}/detectors/yolo (100%) rename {models => models_}/hrnet.py (99%) rename {models => models_}/modules.py (100%) rename {models => models_}/poseresnet.py (98%) create mode 100644 scripts/export-tensorrt-model.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2dc53ca --- /dev/null +++ b/.gitignore @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ diff --git a/.gitmodules b/.gitmodules index c6109a3..7be36f9 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "models/detectors/yolo"] - path = models/detectors/yolo + path = models_/detectors/yolo url = https://github.com/eriklindernoren/PyTorch-YOLOv3 diff --git a/README.md b/README.md index 4ca78dc..36aac97 100644 --- a/README.md +++ b/README.md @@ -15,17 +15,20 @@ This repository provides: - A simple ``HRNet`` implementation in PyTorch (>=1.0) - compatible with official weights (``pose_hrnet_*``). - A simple class (``SimpleHRNet``) that loads the HRNet network for the human pose estimation, loads the pre-trained weights, and make human predictions on a single image or a batch of images. 
-- **NEW** Support for "SimpleBaselines" model based on ResNet - compatible with official weights (``pose_resnet_*``). -- **NEW** Support for multi-GPU inference. -- **NEW** Add option for using YOLOv3-tiny (faster, but less accurate person detection). -- **NEW** Add options for retrieving yolo bounding boxes and HRNet heatmaps. -- Multi-person support with +- Support for "SimpleBaselines" model based on ResNet - compatible with official weights (``pose_resnet_*``). +- Support for multi-GPU inference. +- Add options for retrieving yolo bounding boxes and HRNet heatmaps. +- **NEW** Multi-person support with [YOLOv3](https://github.com/eriklindernoren/PyTorch-YOLOv3/tree/47b7c912877ca69db35b8af3a38d6522681b3bb3) - (enabled by default). + (enabled by default), YOLOv3-tiny, or [YOLOv5](https://github.com/ultralytics/yolov5) by Ultralytics. - A reference code that runs a live demo reading frames from a webcam or a video file. - A relatively-simple code for training and testing the HRNet network. - A specific script for training the network on the COCO dataset. -- **NEW** A [Google Colab notebook](https://github.com/stefanopini/simple-HRNet/issues/84#issuecomment-908199736) showcasing how to use this repository - Sincere thanks to [@basicvisual](https://github.com/basicvisual) and [@wuyenlin](https://github.com/wuyenlin). +- **NEW** An updated [Jupyter Notebook](https://github.com/stefanopini/simple-HRNet/blob/master/SimpleHRNet_notebook.ipynb) compatible with Google Colab showcasing how to use this repository. + - Open In Colab [Click here](https://colab.research.google.com/github/stefanopini/simple-HRNet/blob/master/SimpleHRNet_notebook.ipynb) to open the notebook on Colab! + - Thanks to [@basicvisual](https://github.com/basicvisual) and [@wuyenlin](https://github.com/wuyenlin) for the initial notebook. +- **NEW** Support for TensorRT (thanks to [@gpastal24](https://github.com/gpastal24), see [#99](https://github.com/stefanopini/simple-HRNet/pull/99) and [#100](https://github.com/stefanopini/simple-HRNet/pull/100)). + If you are interested in **HigherHRNet**, please look at [*simple-HigherHRNet*](https://github.com/stefanopini/simple-HigherHRNet) @@ -113,6 +116,24 @@ For help: python scripts/extract-keypoints.py --help ``` +### Converting the model to TensorRT: + +Warning: this requires the installation of TensorRT (see the Nvidia website) and onnx. 
+On some platforms, they can be installed with +``` +pip install tensorrt onnx +``` + +Converting in FP16: +``` +python scripts/export-tensorrt-model.py --device 0 --half +``` + +For help: +``` +python scripts/export-tensorrt-model.py --help +``` + ### Running the training script ``` diff --git a/SimpleHRNet.py b/SimpleHRNet.py index bd4e407..cdca44f 100644 --- a/SimpleHRNet.py +++ b/SimpleHRNet.py @@ -1,11 +1,12 @@ +import os + import cv2 import numpy as np import torch from torchvision.transforms import transforms -from models.hrnet import HRNet -from models.poseresnet import PoseResNet -# from models.detectors.YOLOv3 import YOLOv3 # import only when multi-person is enabled +from models_.hrnet import HRNet +from models_.poseresnet import PoseResNet class SimpleHRNet: @@ -28,10 +29,12 @@ def __init__(self, return_heatmaps=False, return_bounding_boxes=False, max_batch_size=32, - yolo_model_def="./models/detectors/yolo/config/yolov3.cfg", - yolo_class_path="./models/detectors/yolo/data/coco.names", - yolo_weights_path="./models/detectors/yolo/weights/yolov3.weights", - device=torch.device("cpu")): + yolo_version='v3', + yolo_model_def="./models_/detectors/yolo/config/yolov3.cfg", + yolo_class_path="./models_/detectors/yolo/data/coco.names", + yolo_weights_path="./models_/detectors/yolo/weights/yolov3.weights", + device=torch.device("cpu"), + enable_tensorrt=False): """ Initializes a new SimpleHRNet object. HRNet (and YOLOv3) are initialized on the torch.device("device") and @@ -59,14 +62,23 @@ def __init__(self, max_batch_size (int): maximum batch size used in hrnet inference. Useless without multiperson=True. Default: 16 - yolo_model_def (str): path to yolo model definition file. - Default: "./models/detectors/yolo/config/yolov3.cfg" - yolo_class_path (str): path to yolo class definition file. - Default: "./models/detectors/yolo/data/coco.names" - yolo_weights_path (str): path to yolo pretrained weights file. - Default: "./models/detectors/yolo/weights/yolov3.weights.cfg" + yolo_version (str): version of YOLO. Supported versions: `v3`, `v5`. Used when multiperson is True. + Default: "v3" + yolo_model_def (str): path to yolo model definition file. Recommended values: + - `./models_/detectors/yolo/config/yolov3.cfg` if yolo_version is 'v3' + - `./models_/detectors/yolo/config/yolov3-tiny.cfg` if yolo_version is 'v3', to use tiny yolo + - yolov5 model name if yolo_version is 'v5', e.g. `yolov5m` (medium), `yolov5n` (nano) + - `yolov5m.engine` if yolo_version is 'v5', custom version (e.g. tensorrt model) + Default: "./models_/detectors/yolo/config/yolov3.cfg" + yolo_class_path (str): path to yolov3 class definition file. + Default: "./models_/detectors/yolo/data/coco.names" + yolo_weights_path (str): path to yolov3 pretrained weights file. + Default: "./models_/detectors/yolo/weights/yolov3.weights.cfg" device (:class:`torch.device`): the hrnet (and yolo) inference will be run on this device. Default: torch.device("cpu") + enable_tensorrt (bool): Enables tensorrt inference for HRnet. + If enabled, a `.engine` file is expected as `checkpoint_path`. 
+ Default: False """ self.c = c @@ -79,13 +91,20 @@ def __init__(self, self.return_heatmaps = return_heatmaps self.return_bounding_boxes = return_bounding_boxes self.max_batch_size = max_batch_size + self.yolo_version = yolo_version self.yolo_model_def = yolo_model_def self.yolo_class_path = yolo_class_path self.yolo_weights_path = yolo_weights_path self.device = device + self.enable_tensorrt = enable_tensorrt if self.multiperson: - from models.detectors.YOLOv3 import YOLOv3 + if self.yolo_version == 'v3': + from models_.detectors.YOLOv3 import YOLOv3 + elif self.yolo_version == 'v5': + from models_.detectors.YOLOv5 import YOLOv5 + else: + raise ValueError('Unsopported YOLO version.') if model_name in ('HRNet', 'hrnet'): self.model = HRNet(c=c, nof_joints=nof_joints) @@ -94,32 +113,38 @@ def __init__(self, else: raise ValueError('Wrong model name.') - checkpoint = torch.load(checkpoint_path, map_location=self.device) - if 'model' in checkpoint: - self.model.load_state_dict(checkpoint['model']) - else: - self.model.load_state_dict(checkpoint) + if not self.enable_tensorrt: + checkpoint = torch.load(checkpoint_path, map_location=self.device) + if 'model' in checkpoint: + self.model.load_state_dict(checkpoint['model']) + else: + self.model.load_state_dict(checkpoint) - if 'cuda' in str(self.device): - print("device: 'cuda' - ", end="") + if 'cuda' in str(self.device): + print("device: 'cuda' - ", end="") - if 'cuda' == str(self.device): - # if device is set to 'cuda', all available GPUs will be used - print("%d GPU(s) will be used" % torch.cuda.device_count()) - device_ids = None + if 'cuda' == str(self.device): + # if device is set to 'cuda', all available GPUs will be used + print("%d GPU(s) will be used" % torch.cuda.device_count()) + device_ids = None + else: + # if device is set to 'cuda:IDS', only that/those device(s) will be used + print("GPU(s) '%s' will be used" % str(self.device)) + device_ids = [int(x) for x in str(self.device)[5:].split(',')] + + self.model = torch.nn.DataParallel(self.model, device_ids=device_ids) + elif 'cpu' == str(self.device): + print("device: 'cpu'") else: - # if device is set to 'cuda:IDS', only that/those device(s) will be used - print("GPU(s) '%s' will be used" % str(self.device)) - device_ids = [int(x) for x in str(self.device)[5:].split(',')] + raise ValueError('Wrong device name.') - self.model = torch.nn.DataParallel(self.model, device_ids=device_ids) - elif 'cpu' == str(self.device): - print("device: 'cpu'") + self.model = self.model.to(device) + self.model.eval() else: - raise ValueError('Wrong device name.') - - self.model = self.model.to(device) - self.model.eval() + from torch2trt import TRTModule + self.model = TRTModule() + self.model.load_state_dict(torch.load(checkpoint_path)) + self.model.cuda().eval() if not self.multiperson: self.transform = transforms.Compose([ @@ -128,12 +153,17 @@ def __init__(self, ]) else: - self.detector = YOLOv3(model_def=yolo_model_def, - class_path=yolo_class_path, - weights_path=yolo_weights_path, - classes=('person',), - max_batch_size=self.max_batch_size, - device=device) + if self.yolo_version == 'v3': + self.detector = YOLOv3(model_def=yolo_model_def, + class_path=yolo_class_path, + weights_path=yolo_weights_path, + classes=('person',), + max_batch_size=self.max_batch_size, + device=device) + else: + self.detector = YOLOv5(model_def=yolo_model_def, + device=device) + self.transform = transforms.Compose([ transforms.ToPILImage(), transforms.Resize((self.resolution[0], self.resolution[1])), # (height, width) @@ 
-196,10 +226,10 @@ def _predict_single(self, image): else: detections = self.detector.predict_single(image) - nof_people = len(detections) if detections is not None else 0 boxes = np.empty((nof_people, 4), dtype=np.int32) - images = torch.empty((nof_people, 3, self.resolution[0], self.resolution[1])) # (height, width) + # boxes = torch.empty((nof_people, 4),device=self.device) + images = torch.empty((nof_people, 3, self.resolution[0], self.resolution[1]), device=self.device) # (height, width) heatmaps = np.zeros((nof_people, self.nof_joints, self.resolution[0] // 4, self.resolution[1] // 4), dtype=np.float32) @@ -212,21 +242,41 @@ def _predict_single(self, image): # Adapt detections to match HRNet input aspect ratio (as suggested by xtyDoge in issue #14) correction_factor = self.resolution[0] / self.resolution[1] * (x2 - x1) / (y2 - y1) + + # Using padding instead of bbox enlargement, this should reduce cross-person keypoint detection if correction_factor > 1: # increase y side center = y1 + (y2 - y1) // 2 length = int(round((y2 - y1) * correction_factor)) - y1 = max(0, center - length // 2) - y2 = min(image.shape[0], center + length // 2) + x1_new = x1 + x2_new = x2 + y1_new = int(center - length // 2) + y2_new = int(center + length // 2) + pad = (int(abs(y1_new - y1))), int(abs(y2_new - y2)) + pad_tuple = (pad, (0, 0), (0, 0)) + elif correction_factor < 1: - # increase x side center = x1 + (x2 - x1) // 2 length = int(round((x2 - x1) * 1 / correction_factor)) - x1 = max(0, center - length // 2) - x2 = min(image.shape[1], center + length // 2) - - boxes[i] = [x1, y1, x2, y2] - images[i] = self.transform(image[y1:y2, x1:x2, ::-1]) + x1_new = int(center - length // 2) + x2_new = int(center + length // 2) + y1_new = y1 + y2_new = y2 + pad = (abs(x1_new - x1)), int(abs(x2_new - x2)) + pad_tuple = ((0, 0), pad, (0, 0)) + else: + x1_new = x1 + x2_new = x2 + y1_new = y1 + y2_new = y2 + pad_tuple = None + + image_crop = image[y1:y2, x1:x2, ::-1] + if pad_tuple is not None: + image_crop = np.pad(image_crop, pad_tuple) + images[i] = self.transform(image_crop) + boxes[i] = [x1_new, y1_new, x2_new, y2_new] + # boxes[i] = torch.tensor([x1_new, y1_new, x2_new, y2_new]) if images.shape[0] > 0: images = images.to(self.device) @@ -257,6 +307,26 @@ def _predict_single(self, image): pts[i, j, 1] = pt[1] * 1. / (self.resolution[1] // 4) * (boxes[i][2] - boxes[i][0]) + boxes[i][0] pts[i, j, 2] = joint[pt] + # # Torch alternative, it could be faster + # pts = torch.empty((out.shape[0], out.shape[1], 3), dtype=torch.float32,device=self.device) + # # For each human, for each joint: y, x, confidence + # (b, indices) = torch.max(out, dim=2) + # (b, indices) = torch.max(b, dim=2) + # + # (c, indicesc) = torch.max(out, dim=3) + # (c, indicesc) = torch.max(c, dim=2) + # dims = (self.resolution[0]//4, self.resolution[1]//4) + # dim1 = torch.tensor(1. / dims[0], device=self.device) + # dim2 = torch.tensor(1. 
/ dims[1], device=self.device) + # + # for i in range(0, out.shape[0]): + # pts[i, :, 0] = indicesc[i, :] * dim1 * (boxes[i][3] - boxes[i][1]) + boxes[i][1] + # pts[i, :, 1] = indices[i, :] * dim2 * (boxes[i][2] - boxes[i][0]) + boxes[i][0] + # pts[i, :, 2] = c[i, :] + # + # pts = pts.cpu().numpy() + # boxes = boxes.cpu().numpy() + else: pts = np.empty((0, 0, 3), dtype=np.float32) @@ -321,6 +391,8 @@ def _predict_batch(self, images): # Adapt detections to match HRNet input aspect ratio (as suggested by xtyDoge in issue #14) correction_factor = self.resolution[0] / self.resolution[1] * (x2 - x1) / (y2 - y1) + + # TODO Use padding instead of bbox enlargement here too if correction_factor > 1: # increase y side center = y1 + (y2 - y1) // 2 diff --git a/SimpleHRNet_notebook.ipynb b/SimpleHRNet_notebook.ipynb new file mode 100644 index 0000000..6833c63 --- /dev/null +++ b/SimpleHRNet_notebook.ipynb @@ -0,0 +1,537 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "collapsed_sections": [ + "x1P13nZeR3Xj", + "HqHg_VATg6CO", + "ZWUN1C5RgGYS" + ] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU", + "gpuClass": "standard" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Simple HRNet\n", + "This is a light Google Colab notebook showing how to use the [simple-HRNet](https://github.com/stefanopini/simple-HRNet) repository.\n", + "\n", + "It includes the conversion to TensorRT and a test of the converted model.\n", + "Please skip the section \"TensorRT\" if not interested.\n", + "\n", + "Initial idea of running on Google Colab by @basicvisual, initial implementation by @wuyenlin (see [issue #84](https://github.com/stefanopini/simple-HRNet/issues/84))." + ], + "metadata": { + "id": "xZqqnmmNfX1d" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Pytorch" + ], + "metadata": { + "id": "ZFihjwzqhA04" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Clone the repo and install the dependencies" + ], + "metadata": { + "id": "X_ugGAxdd6Hu" + } + }, + { + "cell_type": "code", + "source": [ + "# clone the repo\n", + "!git clone https://github.com/stefanopini/simple-HRNet.git" + ], + "metadata": { + "id": "FIecXpzEY7IJ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "%cd simple-HRNet\n", + "!pwd" + ], + "metadata": { + "id": "JDNRl8a8dl7Z" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# install requirements\n", + "!pip install -r requirements.txt" + ], + "metadata": { + "id": "FGsHqGPNdbHt" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# install vlc to get video codecs\n", + "!apt install vlc" + ], + "metadata": { + "id": "qMynH2IPebr8" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Add yolov3\n", + "Clone yolov3 for multiprocessing support. This can be skipped for single-person applications or if you plan to use YOLO v5 by Ultralytics." 
+ ], + "metadata": { + "id": "x1P13nZeR3Xj" + } + }, + { + "cell_type": "code", + "source": [ + "# download git submodules\n", + "!git submodule update --init --recursive" + ], + "metadata": { + "id": "yqf7BRGWRtUV" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "%cd /content/simple-HRNet/models_/detectors/yolo\n", + "!pip install -q -r requirements.txt\n", + "\n", + "%cd /content/simple-HRNet" + ], + "metadata": { + "id": "vS9cz49gSJeG" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "%cd /content/simple-HRNet/models_/detectors/yolo/weights\n", + "!sh download_weights.sh\n", + "\n", + "%cd /content/simple-HRNet" + ], + "metadata": { + "id": "8v-RpWGwSM7V" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Download HRNet pre-trained weights and test video\n", + "\n", + "Download any of the supported official weights listed [here](https://github.com/stefanopini/simple-HRNet/#installation-instructions).\n", + "\n", + "In the following, we download the weights `pose_hrnet_w48_384x288.pth` from the official Drive link.\n", + "Download of other weights (e.g. `pose_hrnet_w32_256x192.pth`) as well as weights from private Drives is supported too." + ], + "metadata": { + "id": "HqHg_VATg6CO" + } + }, + { + "cell_type": "code", + "source": [ + "!pip install --upgrade --no-cache-dir gdown" + ], + "metadata": { + "id": "pKFdWLLUXyZu" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# download weights\n", + "\n", + "# create weights folder\n", + "%cd /content/simple-HRNet\n", + "!mkdir weights\n", + "%cd /content/simple-HRNet/weights\n", + "\n", + "# download weights pose_hrnet_w48_384x288.pth\n", + "!gdown 1UoJhTtjHNByZSm96W3yFTfU5upJnsKiS\n", + "\n", + "# download weights pose_hrnet_w32_256x192.pth\n", + "!gdown 1zYC7go9EV0XaSlSBjMaiyE_4TcHc_S38\n", + "\n", + "# download weights pose_hrnet_w32_256x256.pth\n", + "!gdown 1_wn2ifmoQprBrFvUCDedjPON4Y6jsN-v\n", + "\n", + "# # download weights from your own Google Drive\n", + "# from glob import glob\n", + "# from google.colab import drive\n", + "# drive.mount('/content/drive')\n", + "# w_list = glob(\"/content/drive//*.pth\")\n", + "# if not w_list:\n", + "# raise FileNotFoundError(\"You haven't downloaded any pre-trained weights!\")\n", + "\n", + "%cd /content/simple-HRNet" + ], + "metadata": { + "id": "3LURZ12cfCcU" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# download a publicly available video (or just get your own)\n", + "!wget https://commondatastorage.googleapis.com/gtv-videos-bucket/sample/WeAreGoingOnBullrun.mp4" + ], + "metadata": { + "id": "OLIrIc14eUPM" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Test the API\n" + ], + "metadata": { + "id": "vcv0B2P7UTxT" + } + }, + { + "cell_type": "code", + "source": [ + "import cv2\n", + "import requests\n", + "import matplotlib.pyplot as plt\n", + "import torch\n", + "from skimage import io\n", + "from PIL import Image\n", + "from SimpleHRNet import SimpleHRNet\n", + "\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "\n", + "# # singleperson, COCO weights\n", + "# model = SimpleHRNet(48, 17, \"./weights/pose_hrnet_w48_384x288.pth\", multiperson=False, device=device)\n", + "\n", + "# # multiperson w/ YOLOv3, COCO weights\n", + "# model = SimpleHRNet(48, 17, 
\"./weights/pose_hrnet_w48_384x288.pth\", device=device)\n", + "\n", + "# # multiperson w/ YOLOv3, COCO weights, small model\n", + "# model = SimpleHRNet(32, 17, \"./weights/pose_hrnet_w32_256x192.pth\", device=device)\n", + "\n", + "# # multiperson w/ YOLOv3, MPII weights\n", + "# model = SimpleHRNet(32, 16, \"./weights/pose_hrnet_w32_256x256.pth\", device=device)\n", + "\n", + "# # multiperson w/ YOLOv5 (medium), COCO weights\n", + "# model = SimpleHRNet(48, 17, \"./weights/pose_hrnet_w48_384x288.pth\", yolo_version='v5', yolo_model_def='yolov5m', device=device)\n", + "\n", + "# multiperson w/ YOLOv5 nano, COCO weights, small model\n", + "model = SimpleHRNet(32, 17, \"./weights/pose_hrnet_w32_256x192.pth\", yolo_version='v5', yolo_model_def='yolov5n', device=device)\n", + "\n", + "url = 'http://images.cocodataset.org/val2017/000000097278.jpg'\n", + "im = Image.open(requests.get(url, stream=True).raw)\n", + "image = io.imread(url)\n", + "\n", + "joints = model.predict(image)" + ], + "metadata": { + "id": "xCXrjhfJUR5C" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "%matplotlib inline\n", + "from misc.visualization import joints_dict\n", + "\n", + "def plot_joints(ax, output):\n", + " bones = joints_dict()[\"coco\"][\"skeleton\"]\n", + " # bones = joints_dict()[\"mpii\"][\"skeleton\"]\n", + "\n", + " for bone in bones:\n", + " xS = [output[:,bone[0],1], output[:,bone[1],1]]\n", + " yS = [output[:,bone[0],0], output[:,bone[1],0]]\n", + " ax.plot(xS, yS, linewidth=3, c=(0,0.3,0.7))\n", + " ax.scatter(joints[:,:,1],joints[:,:,0], s=20, c='r')\n", + "\n", + "fig = plt.figure(figsize=(60/2.54, 30/2.54))\n", + "ax = fig.add_subplot(121)\n", + "ax.imshow(Image.open(requests.get(url, stream=True).raw))\n", + "ax = fig.add_subplot(122)\n", + "ax.imshow(Image.open(requests.get(url, stream=True).raw))\n", + "plot_joints(ax, joints)\n", + "plt.show()" + ], + "metadata": { + "id": "aYNkSzCGUqMF" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Test the live script\n", + "This step can be skipped if interested in the TensorRT conversion." + ], + "metadata": { + "id": "ZWUN1C5RgGYS" + } + }, + { + "cell_type": "code", + "source": [ + "# # test the live script with default params (multiperson with yolo v3)\n", + "# !python ./scripts/live-demo.py --filename WeAreGoingOnBullrun.mp4 --save_video\n", + "\n", + "# # test the live script with tiny yolo (v3)\n", + "# !python ./scripts/live-demo.py --filename WeAreGoingOnBullrun.mp4 --save_video --use_tiny_yolo\n", + "\n", + "# # test the live script with yolo v5\n", + "# !python ./scripts/live-demo.py --filename WeAreGoingOnBullrun.mp4 --save_video --yolo_version v5\n", + "\n", + "# test the live script with tiny yolo v5 (tensorrt yolo v5)\n", + "!python ./scripts/live-demo.py --filename WeAreGoingOnBullrun.mp4 --save_video --yolo_version v5 --use_tiny_yolo" + ], + "metadata": { + "id": "VEPfVe2bg1dS" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Now check out the video output.avi\n" + ], + "metadata": { + "id": "RsTTv7A5gGvF" + } + }, + { + "cell_type": "markdown", + "source": [ + "## TensorRT\n", + "This section install TensorRT 8.5, converts the model to TensorRT (.engine) and tests the converted model.\n", + "\n", + "Tested with TensorRT 8.5.1-1+cuda11.8 and python package tensorrt 8.5.1.7 ." 
+ ], + "metadata": { + "id": "YHj3FQEyf1yD" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Install TensorRT\n", + "A GPU is needed for this step. Please change the runtime type to \"GPU\".\n" + ], + "metadata": { + "id": "VsFWYxaNc-gl" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LsAlxRGVXhrt" + }, + "outputs": [], + "source": [ + "# check a GPU runtime is selected\n", + "!nvidia-smi" + ] + }, + { + "cell_type": "code", + "source": [ + "%%bash\n", + "wget https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/nvidia-machine-learning-repo-ubuntu1804_1.0.0-1_amd64.deb\n", + "\n", + "dpkg -i nvidia-machine-learning-repo-*.deb\n", + "apt-get update\n", + "\n", + "sudo apt-get install libnvinfer8 python3-libnvinfer" + ], + "metadata": { + "id": "9vZ35qN5XkHE" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# check TensorRT version\n", + "print(\"TensorRT version: \")\n", + "!dpkg -l | grep nvinfer" + ], + "metadata": { + "id": "GlGh_J2WYH8u" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# install TensorRT for python\n", + "!pip install tensorrt" + ], + "metadata": { + "id": "nhzVoykoYAWJ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# clone the converion tool torch2trt\n", + "%cd /content\n", + "!git clone https://github.com/NVIDIA-AI-IOT/torch2trt" + ], + "metadata": { + "id": "NUR0P_HklFbz" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# install torch2trt\n", + "%cd /content/torch2trt\n", + "!python setup.py install" + ], + "metadata": { + "id": "Y97nln2AX35c" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "%cd /content/simple-HRNet" + ], + "metadata": { + "id": "UC-xqiy5X5vk" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Export the model with tensorrt" + ], + "metadata": { + "id": "I2u6Xn72eEBE" + } + }, + { + "cell_type": "code", + "source": [ + "# Convert the smaller HRNet model to TensorRT - it may take a while...\n", + "!python scripts/export-tensorrt-model.py --half \\\n", + " --weights \"./weights/pose_hrnet_w32_256x192.pth\" --hrnet_c 32 --image_resolution '(256, 192)'" + ], + "metadata": { + "id": "S57JsLacdnoF" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "#### [Optional] Export yolov5 with TensorRT" + ], + "metadata": { + "id": "ckdDXNJzmxt_" + } + }, + { + "cell_type": "code", + "source": [ + "# Optional - Convert yolov5 (nano) to tensorrt too\n", + "!python /root/.cache/torch/hub/ultralytics_yolov5_master/export.py --weights yolov5n.pt --include engine --device 0 --half" + ], + "metadata": { + "id": "3Hls1HlCl44F" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Test the tensorrt model" + ], + "metadata": { + "id": "npgGj4cGemZd" + } + }, + { + "cell_type": "code", + "source": [ + "# Run inference with the converted TensorRT model\n", + "!python scripts/live-demo.py --enable_tensorrt --filename=WeAreGoingOnBullrun.mp4 --hrnet_weights='weights/hrnet_trt.engine' \\\n", + " --hrnet_c 32 --image_resolution \"(256, 192)\" --yolo_version v5 --use_tiny_yolo --save_video\n" + ], + "metadata": { + "id": "LnIpbqV0fVps" + }, + "execution_count": null, + "outputs": [] + }, + { + 
"cell_type": "markdown", + "source": [ + "Now check out the video output.avi\n" + ], + "metadata": { + "id": "WbQk0PeNnM5-" + } + } + ] +} \ No newline at end of file diff --git a/datasets/LiveCamera.py b/datasets/LiveCamera.py index fe8daed..3c2dc03 100644 --- a/datasets/LiveCamera.py +++ b/datasets/LiveCamera.py @@ -3,7 +3,7 @@ import torch from torch.utils.data import Dataset from torchvision import transforms -from models.detectors.YOLOv3 import YOLOv3 +from models_.detectors.YOLOv3 import YOLOv3 class LiveCameraDataset(Dataset): @@ -27,9 +27,9 @@ def __init__(self, camera_id=0, epoch_length=1, resolution=(384, 288), interpola ]) else: - self.detector = YOLOv3(model_def="./models/detectors/yolo/config/yolov3.cfg", - class_path="./models/detectors/yolo/data/coco.names", - weights_path="./models/detectors/yolo/weights/yolov3.weights", + self.detector = YOLOv3(model_def="./models_/detectors/yolo/config/yolov3.cfg", + class_path="./models_/detectors/yolo/data/coco.names", + weights_path="./models_/detectors/yolo/weights/yolov3.weights", classes=('person',), device=device) self.transform = transforms.Compose([ diff --git a/misc/utils.py b/misc/utils.py index b425da9..08fff48 100644 --- a/misc/utils.py +++ b/misc/utils.py @@ -360,7 +360,11 @@ def oks_iou(g, d, a_g, a_d, sigmas=None, in_vis_thre=None): if in_vis_thre is not None: ind = list(vg > in_vis_thre) and list(vd > in_vis_thre) e = e[ind] + + e = e[e <=2^32 -1] + ious[n_d] = np.sum(np.exp(-e)) / e.shape[0] if e.shape[0] != 0 else 0.0 + return ious diff --git a/models/__init__.py b/models_/__init__.py similarity index 100% rename from models/__init__.py rename to models_/__init__.py diff --git a/models/detectors/YOLOv3.py b/models_/detectors/YOLOv3.py similarity index 97% rename from models/detectors/YOLOv3.py rename to models_/detectors/YOLOv3.py index 63380f3..62d7959 100644 --- a/models/detectors/YOLOv3.py +++ b/models_/detectors/YOLOv3.py @@ -7,7 +7,7 @@ import torch from torchvision.transforms import transforms -sys.path.append(os.path.join(os.getcwd(), 'models', 'detectors', 'yolo')) +sys.path.append(os.path.join(os.getcwd(), 'models_', 'detectors', 'yolo')) from .yolo.models import Darknet from .yolo.utils.utils import load_classes, non_max_suppression @@ -29,10 +29,10 @@ def letterbox(img, new_shape=416, color=(127.5, 127.5, 127.5), mode='auto'): ratio = max(new_shape) / max(shape) # ratio = new / old new_unpad = (int(round(shape[1] * ratio)), int(round(shape[0] * ratio))) - if mode is 'auto': # minimum rectangle + if mode == 'auto': # minimum rectangle dw = np.mod(new_shape - new_unpad[0], 32) / 2 # width padding dh = np.mod(new_shape - new_unpad[1], 32) / 2 # height padding - elif mode is 'square': # square + elif mode == 'square': # square dw = (new_shape - new_unpad[0]) / 2 # width padding dh = (new_shape - new_unpad[1]) / 2 # height padding else: diff --git a/models_/detectors/YOLOv5.py b/models_/detectors/YOLOv5.py new file mode 100644 index 0000000..68e13f0 --- /dev/null +++ b/models_/detectors/YOLOv5.py @@ -0,0 +1,103 @@ +import os + +import cv2 +import numpy as np +import torch + + +# from https://github.com/ultralytics/yolov5 +def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32): + # Resize and pad image while meeting stride-multiple constraints + shape = im.shape[:2] # current shape [height, width] + if isinstance(new_shape, int): + new_shape = (new_shape, new_shape) + + # Scale ratio (new / old) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) 
+ if not scaleup: # only scale down, do not scale up (for better val mAP) + r = min(r, 1.0) + + # Compute padding + ratio = r, r # width, height ratios + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding + if auto: # minimum rectangle + dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding + elif scaleFill: # stretch + dw, dh = 0.0, 0.0 + new_unpad = (new_shape[1], new_shape[0]) + ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios + + dw /= 2 # divide padding into 2 sides + dh /= 2 + + if shape[::-1] != new_unpad: # resize + im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border + return im, ratio, (dw, dh) + + +class YOLOv5: + def __init__(self, + model_def='', + model_folder='./models_/detectors/yolov5', + image_resolution=(640, 640), + conf_thres=0.3, + device=torch.device('cpu')): + + self.model_def = model_def + self.model_folder = model_folder + self.image_resolution = image_resolution + self.conf_thres = conf_thres + self.device = device + self.trt_model = self.model_def.endswith('.engine') + + # Set up model + if self.trt_model: + # if the yolo model ends with 'engine', it is loaded as a custom YOLOv5 pre-trained model + print(f"Loading custom yolov5 model {self.model_def}") + self.model = torch.hub.load('ultralytics/yolov5', 'custom', self.model_def) + else: + # load the pre-trained YOLOv5 in a pre-defined folder + if not os.path.exists(self.model_folder): + os.makedirs(self.model_folder) + self.model = torch.hub.load('ultralytics/yolov5', self.model_def, pretrained=True) + + self.model = self.model.to(self.device) + self.model.eval() # Set in evaluation mode + + def predict_single(self, image, color_mode='BGR'): + image = image.copy() + if self.trt_model: + # when running with TensorRT, the image must have fixed size + image, (ratiow, ratioh), (dw, dh) = letterbox(image, self.image_resolution, stride=self.model.stride, + auto=False, scaleFill=False) # padded resize + + if color_mode == 'BGR': + # all YOLO models expect RGB + # See https://github.com/ultralytics/yolov5/issues/9913#issuecomment-1290736061 and + # https://github.com/ultralytics/yolov5/blob/8ca182613499c323a411f559b7b5ea072122c897/models/common.py#L662 + image = image[..., ::-1] + + with torch.no_grad(): + detections = self.model(image) + + detections = detections.xyxy[0] + detections = detections[detections[:, 4] >= self.conf_thres] + + detections = detections[detections[:, 5] == 0.] 
# person + + # adding a fake class confidence to maintain compatibility with YOLOv3 + detections = torch.cat((detections[:, :5], detections[:, 4:5], detections[:, 5:]), dim=1) + + if self.trt_model: + # account for the image resize fixing the xyxy locations + detections[:, [0, 2]] = (detections[:, [0, 2]] - dw) / ratiow + detections[:, [1, 3]] = (detections[:, [1, 3]] - dh) / ratioh + + return detections + + def predict(self, images, color_mode='BGR'): + raise NotImplementedError("Not currently supported.") diff --git a/models/detectors/yolo b/models_/detectors/yolo similarity index 100% rename from models/detectors/yolo rename to models_/detectors/yolo diff --git a/models/hrnet.py b/models_/hrnet.py similarity index 99% rename from models/hrnet.py rename to models_/hrnet.py index 3a079ce..830992d 100644 --- a/models/hrnet.py +++ b/models_/hrnet.py @@ -1,6 +1,6 @@ import torch from torch import nn -from models.modules import BasicBlock, Bottleneck +from models_.modules import BasicBlock, Bottleneck class StageModule(nn.Module): diff --git a/models/modules.py b/models_/modules.py similarity index 100% rename from models/modules.py rename to models_/modules.py diff --git a/models/poseresnet.py b/models_/poseresnet.py similarity index 98% rename from models/poseresnet.py rename to models_/poseresnet.py index 96b04c5..6e4dd70 100644 --- a/models/poseresnet.py +++ b/models_/poseresnet.py @@ -1,6 +1,6 @@ import torch from torch import nn -from models.modules import BasicBlock, Bottleneck +from models_.modules import BasicBlock, Bottleneck resnet_spec = { diff --git a/scripts/export-tensorrt-model.py b/scripts/export-tensorrt-model.py new file mode 100644 index 0000000..e2d68f5 --- /dev/null +++ b/scripts/export-tensorrt-model.py @@ -0,0 +1,53 @@ +import argparse +import ast +import os +import sys + +import torch +from torch2trt import torch2trt, TRTModule + +sys.path.insert(1, os.getcwd()) +from models_.hrnet import HRNet + + +def convert_to_trt(args): + """ + TensorRT conversion function for the HRNet models using torch2trt. + Requires the definition of the image resolution and the max batch size, supports FP16 mode (half precision). + """ + pose = HRNet(args.hrnet_c, 17) + + pose.load_state_dict(torch.load(args.weights)) + pose.cuda().eval() + + image_resolution = ast.literal_eval(args.image_resolution) + x = torch.ones(1, 3, image_resolution[0], image_resolution[1]).cuda() + print("Starting conversion to TensorRT with torch2trt...") + net_trt = torch2trt(pose, [x], max_batch_size=args.batch_size, fp16_mode=args.half) + torch.save(net_trt.state_dict(), args.output_path) + print(f"Conversion to TensorRT completed! 
Model saved at {args.output_path}") + + +def parse_opt(): + """Parses the arguments for the trt conversion.""" + parser = argparse.ArgumentParser() + parser.add_argument("--weights", "-w", help="the model weights file", type=str, + default='./weights/pose_hrnet_w48_384x288.pth') + parser.add_argument("--hrnet_c", "-c", help="HRNet channels, either 32 or 48 (default)", type=int, default=48) + parser.add_argument("--hrnet_j", "-j", help="HRNet number of joints, 17 (default)", type=int, default=17) + parser.add_argument("--image_resolution", "-r", help="image resolution, 256x192 or 384x288 (default)", type=str, + default="(384, 288)") + parser.add_argument("--batch_size", "-b", help="maximum batch size for trt", type=int, default=16) + parser.add_argument('--half', action='store_true', help='FP16 half-precision export') + parser.add_argument("--output_path", help="output path, default ./weights/hrnet_trt.engine", type=str, + default="./weights/hrnet_trt.engine") + return parser.parse_args() + + +def main(): + args = parse_opt() + convert_to_trt(args) + + +if __name__ == '__main__': + main() diff --git a/scripts/extract-keypoints.py b/scripts/extract-keypoints.py index 2a4a04c..0cd5361 100644 --- a/scripts/extract-keypoints.py +++ b/scripts/extract-keypoints.py @@ -13,8 +13,9 @@ from misc.visualization import check_video_rotation -def main(format, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, image_resolution, single_person, use_tiny_yolo, - max_batch_size, csv_output_filename, csv_delimiter, json_output_filename, device): +def main(format, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, image_resolution, single_person, yolo_version, + use_tiny_yolo, max_batch_size, csv_output_filename, csv_delimiter, json_output_filename, device, + enable_tensorrt): if device is not None: device = torch.device(device) else: @@ -43,14 +44,28 @@ def main(format, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, image_resol fd = open(json_output_filename, 'wt') json_data = {} - if use_tiny_yolo: - yolo_model_def = "./models/detectors/yolo/config/yolov3-tiny.cfg" - yolo_class_path = "./models/detectors/yolo/data/coco.names" - yolo_weights_path = "./models/detectors/yolo/weights/yolov3-tiny.weights" + if yolo_version == 'v3': + if use_tiny_yolo: + yolo_model_def = "./models_/detectors/yolo/config/yolov3-tiny.cfg" + yolo_weights_path = "./models_/detectors/yolo/weights/yolov3-tiny.weights" + else: + yolo_model_def = "./models_/detectors/yolo/config/yolov3.cfg" + yolo_weights_path = "./models_/detectors/yolo/weights/yolov3.weights" + yolo_class_path = "./models_/detectors/yolo/data/coco.names" + elif yolo_version == 'v5': + # YOLOv5 comes in different sizes: n(ano), s(mall), m(edium), l(arge), x(large) + if use_tiny_yolo: + yolo_model_def = "yolov5n" # this is the nano version + else: + yolo_model_def = "yolov5m" # this is the medium version + if enable_tensorrt: + yolo_trt_filename = yolo_model_def + ".engine" + if os.path.exists(yolo_trt_filename): + yolo_model_def = yolo_trt_filename + yolo_class_path = "" + yolo_weights_path = "" else: - yolo_model_def = "./models/detectors/yolo/config/yolov3.cfg" - yolo_class_path = "./models/detectors/yolo/data/coco.names" - yolo_weights_path = "./models/detectors/yolo/weights/yolov3.weights" + raise ValueError('Unsopported YOLO version.') model = SimpleHRNet( hrnet_c, @@ -60,18 +75,23 @@ def main(format, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, image_resol resolution=image_resolution, multiperson=not single_person, max_batch_size=max_batch_size, + 
yolo_version=yolo_version, yolo_model_def=yolo_model_def, yolo_class_path=yolo_class_path, yolo_weights_path=yolo_weights_path, - device=device + device=device, + enable_tensorrt=enable_tensorrt ) index = 0 + t_start = time.time() while True: t = time.time() ret, frame = video.read() if not ret: + t_end = time.time() + print("\n Total Time: ", t_end - t_start) break if rotation_code is not None: frame = cv2.rotate(frame, rotation_code) @@ -123,7 +143,7 @@ def main(format, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, image_resol type=str, default=None) parser.add_argument("--filename", "-f", help="open the specified video", type=str, default=None) - parser.add_argument("--hrnet_m", "-m", help="network model - HRNet or PoseResNet", type=str, default='HRNet') + parser.add_argument("--hrnet_m", "-m", help="network model - 'HRNet' or 'PoseResNet'", type=str, default='HRNet') parser.add_argument("--hrnet_c", "-c", help="hrnet parameters - number of channels (if model is HRNet), " "resnet size (if model is PoseResNet)", type=int, default=48) parser.add_argument("--hrnet_j", "-j", help="hrnet parameters - number of joints", type=int, default=17) @@ -134,8 +154,13 @@ def main(format, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, image_resol help="disable the multiperson detection (YOLOv3 or an equivalen detector is required for" "multiperson detection)", action="store_true") + parser.add_argument("--yolo_version", + help="Use the specified version of YOLO. Supported versions: `v3` (default), `v5`.", + type=str, default="v3") parser.add_argument("--use_tiny_yolo", - help="Use YOLOv3-tiny in place of YOLOv3 (faster person detection). Ignored if --single_person", + help="Use YOLOv3-tiny in place of YOLOv3 (faster person detection) if `yolo_version` is `v3`." + "Use YOLOv5n(ano) in place of YOLOv5m(edium) if `yolo_version` is `v5`." + "Ignored if --single_person", action="store_true") parser.add_argument("--max_batch_size", help="maximum batch size used for inference", type=int, default=16) parser.add_argument("--csv_output_filename", help="filename of the csv that will be written.", @@ -148,5 +173,11 @@ def main(format, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, image_resol "set to `cuda:IDS` to use one or more specific GPUs " "(e.g. `cuda:0` `cuda:1,2`); " "set to `cpu` to run on cpu.", type=str, default=None) + parser.add_argument("--enable_tensorrt", + help="Enables tensorrt inference for HRnet. If enabled, a `.engine` file is expected as " + "weights (`--hrnet_weights`). 
This option should be used only after the HRNet engine " + "file has been generated using the script `scripts/export-tensorrt-model.py`.", + action='store_true') + args = parser.parse_args() main(**args.__dict__) diff --git a/scripts/live-demo.py b/scripts/live-demo.py index da4bca1..6bc6c34 100644 --- a/scripts/live-demo.py +++ b/scripts/live-demo.py @@ -13,9 +13,10 @@ from misc.visualization import draw_points, draw_skeleton, draw_points_and_skeleton, joints_dict, check_video_rotation from misc.utils import find_person_id_associations + def main(camera_id, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, hrnet_joints_set, image_resolution, - single_person, use_tiny_yolo, disable_tracking, max_batch_size, disable_vidgear, save_video, video_format, - video_framerate, device): + single_person, yolo_version, use_tiny_yolo, disable_tracking, max_batch_size, disable_vidgear, save_video, + video_format, video_framerate, device, enable_tensorrt): if device is not None: device = torch.device(device) else: @@ -43,14 +44,28 @@ def main(camera_id, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, hrnet_jo else: video = CamGear(camera_id).start() - if use_tiny_yolo: - yolo_model_def="./models/detectors/yolo/config/yolov3-tiny.cfg" - yolo_class_path="./models/detectors/yolo/data/coco.names" - yolo_weights_path="./models/detectors/yolo/weights/yolov3-tiny.weights" + if yolo_version == 'v3': + if use_tiny_yolo: + yolo_model_def = "./models_/detectors/yolo/config/yolov3-tiny.cfg" + yolo_weights_path = "./models_/detectors/yolo/weights/yolov3-tiny.weights" + else: + yolo_model_def = "./models_/detectors/yolo/config/yolov3.cfg" + yolo_weights_path = "./models_/detectors/yolo/weights/yolov3.weights" + yolo_class_path = "./models_/detectors/yolo/data/coco.names" + elif yolo_version == 'v5': + # YOLOv5 comes in different sizes: n(ano), s(mall), m(edium), l(arge), x(large) + if use_tiny_yolo: + yolo_model_def = "yolov5n" # this is the nano version + else: + yolo_model_def = "yolov5m" # this is the medium version + if enable_tensorrt: + yolo_trt_filename = yolo_model_def + ".engine" + if os.path.exists(yolo_trt_filename): + yolo_model_def = yolo_trt_filename + yolo_class_path = "" + yolo_weights_path = "" else: - yolo_model_def="./models/detectors/yolo/config/yolov3.cfg" - yolo_class_path="./models/detectors/yolo/data/coco.names" - yolo_weights_path="./models/detectors/yolo/weights/yolov3.weights" + raise ValueError('Unsopported YOLO version.') model = SimpleHRNet( hrnet_c, @@ -61,10 +76,12 @@ def main(camera_id, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, hrnet_jo multiperson=not single_person, return_bounding_boxes=not disable_tracking, max_batch_size=max_batch_size, + yolo_version=yolo_version, yolo_model_def=yolo_model_def, yolo_class_path=yolo_class_path, yolo_weights_path=yolo_weights_path, - device=device + device=device, + enable_tensorrt=enable_tensorrt ) if not disable_tracking: @@ -72,13 +89,15 @@ def main(camera_id, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, hrnet_jo prev_pts = None prev_person_ids = None next_person_id = 0 - + t_start = time.time() while True: t = time.time() if filename is not None or disable_vidgear: ret, frame = video.read() if not ret: + t_end = time.time() + print("\n Total Time: ", t_end - t_start) break if rotation_code is not None: frame = cv2.rotate(frame, rotation_code) @@ -118,8 +137,11 @@ def main(camera_id, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, hrnet_jo points_color_palette='gist_rainbow', skeleton_color_palette='jet', 
points_palette_samples=10) + # for box in boxes: + # cv2.rectangle(frame,(box[0],box[1]),(box[2],box[3]),(255,255,255),2) + fps = 1. / (time.time() - t) - print('\rframerate: %f fps' % fps, end='') + print('\rframerate: %f fps, for %d person(s) ' % (fps,len(pts)), end='') if has_display: cv2.imshow('frame.png', frame) @@ -162,8 +184,13 @@ def main(camera_id, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, hrnet_jo help="disable the multiperson detection (YOLOv3 or an equivalen detector is required for" "multiperson detection)", action="store_true") + parser.add_argument("--yolo_version", + help="Use the specified version of YOLO. Supported versions: `v3` (default), `v5`.", + type=str, default="v3") parser.add_argument("--use_tiny_yolo", - help="Use YOLOv3-tiny in place of YOLOv3 (faster person detection). Ignored if --single_person", + help="Use YOLOv3-tiny in place of YOLOv3 (faster person detection) if `yolo_version` is `v3`." + "Use YOLOv5n(ano) in place of YOLOv5m(edium) if `yolo_version` is `v5`." + "Ignored if --single_person", action="store_true") parser.add_argument("--disable_tracking", help="disable the skeleton tracking and temporal smoothing functionality", @@ -174,12 +201,18 @@ def main(camera_id, filename, hrnet_m, hrnet_c, hrnet_j, hrnet_weights, hrnet_jo action="store_true") # see https://pypi.org/project/vidgear/ parser.add_argument("--save_video", help="save output frames into a video.", action="store_true") parser.add_argument("--video_format", help="fourcc video format. Common formats: `MJPG`, `XVID`, `X264`." - "See http://www.fourcc.org/codecs.php", type=str, default='MJPG') + "See http://www.fourcc.org/codecs.php", type=str, default='MJPG') parser.add_argument("--video_framerate", help="video framerate", type=float, default=30) parser.add_argument("--device", help="device to be used (default: cuda, if available)." "Set to `cuda` to use all available GPUs (default); " "set to `cuda:IDS` to use one or more specific GPUs " "(e.g. `cuda:0` `cuda:1,2`); " "set to `cpu` to run on cpu.", type=str, default=None) + parser.add_argument("--enable_tensorrt", + help="Enables tensorrt inference for HRnet. If enabled, a `.engine` file is expected as " + "weights (`--hrnet_weights`). This option should be used only after the HRNet engine " + "file has been generated using the script `scripts/export-tensorrt-model.py`.", + action='store_true') + args = parser.parse_args() main(**args.__dict__) diff --git a/testing/Test.py b/testing/Test.py index 263e3b5..6d3ad60 100644 --- a/testing/Test.py +++ b/testing/Test.py @@ -10,7 +10,7 @@ from misc.checkpoint import load_checkpoint from misc.utils import flip_tensor, flip_back from misc.visualization import save_images -from models.hrnet import HRNet +from models_.hrnet import HRNet class Test(object): diff --git a/training/Train.py b/training/Train.py index badbd03..2c4ea00 100644 --- a/training/Train.py +++ b/training/Train.py @@ -13,7 +13,7 @@ from misc.checkpoint import save_checkpoint, load_checkpoint from misc.utils import flip_tensor, flip_back from misc.visualization import save_images -from models.hrnet import HRNet +from models_.hrnet import HRNet class Train(object):