From fe717132d9f7221fb8a33436d500c94e7a71651b Mon Sep 17 00:00:00 2001 From: raghav Date: Thu, 1 Feb 2024 19:51:32 +0530 Subject: [PATCH 1/2] Adjusting LineZoneAnnotator to display per-class counts in annotations. #790 --- supervision/detection/line_counter.py | 67 +++++++++++++++++---------- 1 file changed, 42 insertions(+), 25 deletions(-) diff --git a/supervision/detection/line_counter.py b/supervision/detection/line_counter.py index ae662d116..59fc4d092 100644 --- a/supervision/detection/line_counter.py +++ b/supervision/detection/line_counter.py @@ -51,8 +51,10 @@ def __init__( self.vector = Vector(start=start, end=end) self.limits = self.calculate_region_of_interest_limits(vector=self.vector) self.tracker_state: Dict[str, bool] = {} - self.in_count: int = 0 - self.out_count: int = 0 + #self.in_count: int = 0 + #self.out_count: int = 0 + self.class_in_count: Dict[int, int] = {} # Per-class in count + self.class_out_count: Dict[int, int] = {} # Per-class out count self.triggering_anchors = triggering_anchors @staticmethod @@ -123,7 +125,8 @@ def trigger(self, detections: Detections) -> Tuple[np.ndarray, np.ndarray]: for i, tracker_id in enumerate(detections.tracker_id): if tracker_id is None: continue - + + class_label = detections.class_labels[i] # To get class label box_anchors = [Point(x=x, y=y) for x, y in all_anchors[:, i, :]] in_limits = all( @@ -156,10 +159,21 @@ def trigger(self, detections: Detections) -> Tuple[np.ndarray, np.ndarray]: if tracker_state: self.in_count += 1 crossed_in[i] = True + + # Update per-class in count + if class_label not in self.class_in_count: + self.class_in_count[class_label] = 0 + self.class_in_count[class_label] += 1 + else: self.out_count += 1 crossed_out[i] = True + # Update per-class out count + if class_label not in self.class_out_count: + self.class_out_count[class_label] = 0 + self.class_out_count[class_label] += 1 + return crossed_in, crossed_out @@ -284,28 +298,31 @@ def annotate(self, frame: np.ndarray, line_counter: LineZone) -> np.ndarray: ) if self.display_in_count: - in_text = ( - f"{self.custom_in_text}: {line_counter.in_count}" - if self.custom_in_text is not None - else f"in: {line_counter.in_count}" - ) - self._annotate_count( - frame=frame, - center_text_anchor=text_anchor.center, - text=in_text, - is_in_count=True, - ) + for class_label, count in line_counter.class_in_count.items(): + in_text = ( + f"{self.custom_in_text}: {count} - Class {class_label}" + if self.custom_in_text is not None + else f"in: {count} - Class {class_label}" + ) + self._annotate_count( + frame=frame, + center_text_anchor=text_anchor.center, + text=in_text, + is_in_count=True, + ) if self.display_out_count: - out_text = ( - f"{self.custom_out_text}: {line_counter.out_count}" - if self.custom_out_text is not None - else f"out: {line_counter.out_count}" - ) - self._annotate_count( - frame=frame, - center_text_anchor=text_anchor.center, - text=out_text, - is_in_count=False, - ) + for class_label, count in line_counter.class_out_count.items(): + out_text = ( + f"{self.custom_out_text}: {count} - Class {class_label}" + if self.custom_out_text is not None + else f"out: {count} - Class {class_label}" + ) + self._annotate_count( + frame=frame, + center_text_anchor=text_anchor.center, + text=out_text, + is_in_count=False, + ) + return frame From 92f780b0688db523b8f5f91c95edaf00ae966c1b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 1 Feb 2024 14:55:55 +0000 Subject: [PATCH 2/2] =?UTF-8?q?fix(pre=5Fcommit):=20=F0=9F=8E=A8=20auto=20?= =?UTF-8?q?format=20pre-commit=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- supervision/detection/line_counter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/supervision/detection/line_counter.py b/supervision/detection/line_counter.py index 59fc4d092..283c901dc 100644 --- a/supervision/detection/line_counter.py +++ b/supervision/detection/line_counter.py @@ -51,8 +51,8 @@ def __init__( self.vector = Vector(start=start, end=end) self.limits = self.calculate_region_of_interest_limits(vector=self.vector) self.tracker_state: Dict[str, bool] = {} - #self.in_count: int = 0 - #self.out_count: int = 0 + # self.in_count: int = 0 + # self.out_count: int = 0 self.class_in_count: Dict[int, int] = {} # Per-class in count self.class_out_count: Dict[int, int] = {} # Per-class out count self.triggering_anchors = triggering_anchors @@ -125,7 +125,7 @@ def trigger(self, detections: Detections) -> Tuple[np.ndarray, np.ndarray]: for i, tracker_id in enumerate(detections.tracker_id): if tracker_id is None: continue - + class_label = detections.class_labels[i] # To get class label box_anchors = [Point(x=x, y=y) for x, y in all_anchors[:, i, :]] @@ -324,5 +324,5 @@ def annotate(self, frame: np.ndarray, line_counter: LineZone) -> np.ndarray: text=out_text, is_in_count=False, ) - + return frame