From 94c10a669485d138267b633e30ca1e577f2b1a21 Mon Sep 17 00:00:00 2001 From: "Yukai Yang (Alexis)" Date: Sat, 22 Jan 2022 01:42:21 -0500 Subject: [PATCH] Multi-class tracking fix (#247) * Use binary search to split class bboxes * Remove fastmath * Fix pep8 * Use default argument for bisect * Revert find_split_indices and fix imports --- fastmot/mot.py | 17 +++++++++++++++-- fastmot/utils/numba.py | 13 +++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/fastmot/mot.py b/fastmot/mot.py index d756c671..ae43f90f 100644 --- a/fastmot/mot.py +++ b/fastmot/mot.py @@ -2,6 +2,7 @@ from enum import Enum import logging import numpy as np +import numba as nb import cv2 from .detector import SSDDetector, YOLODetector, PublicDetector @@ -9,7 +10,7 @@ from .tracker import MultiTracker from .utils import Profiler from .utils.visualization import Visualizer -from .utils.numba import find_split_indices +from .utils.numba import bisect_right LOGGER = logging.getLogger(__name__) @@ -143,7 +144,8 @@ def step(self, frame): detections = self.detector.postprocess() with Profiler('extract'): - cls_bboxes = np.split(detections.tlbr, find_split_indices(detections.label)) + cls_bboxes = self._split_bboxes_by_cls(detections.tlbr, detections.label, + self.class_ids) for extractor, bboxes in zip(self.extractors, cls_bboxes): extractor.extract_async(frame, bboxes) @@ -175,6 +177,17 @@ def print_timing_info(): f"{Profiler.get_avg_millis('extract'):>6.3f} ms") LOGGER.debug(f"{'association time:':<37}{Profiler.get_avg_millis('assoc'):>6.3f} ms") + @staticmethod + @nb.njit(cache=True) + def _split_bboxes_by_cls(bboxes, labels, class_ids): + cls_bboxes = [] + begin = 0 + for cls_id in class_ids: + end = bisect_right(labels, cls_id, begin) + cls_bboxes.append(bboxes[begin:end]) + begin = end + return cls_bboxes + def _draw(self, frame, detections): visible_tracks = list(self.visible_tracks()) self.visualizer.render(frame, visible_tracks, detections, self.tracker.klt_bboxes.values(), diff --git a/fastmot/utils/numba.py b/fastmot/utils/numba.py index dfe588f8..7a798789 100644 --- a/fastmot/utils/numba.py +++ b/fastmot/utils/numba.py @@ -39,6 +39,19 @@ def mask_area(mask): return count +@nb.njit(fastmath=True, cache=True) +def bisect_right(arr, val, left=0): + """Utility to search a value in a sorted array.""" + right = len(arr) + while left < right: + mid = left + (right - left) // 2 + if arr[mid] >= val: + left = mid + 1 + else: + right = mid + return left + + @nb.njit(fastmath=True, cache=True) def find_split_indices(arr): """Utility to find indices of unique elements in sorted array."""