From a2195aa3789403e24ca44b941e0f14c3b49655ce Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Wed, 22 May 2024 14:24:52 +0200 Subject: [PATCH 1/8] initial commit adding support for `from_lmm` and specifically for PaliGemma --- supervision/detection/core.py | 47 ++++++++++++++ supervision/detection/lmm.py | 62 +++++++++++++++++++ test/detection/test_lmm.py | 113 ++++++++++++++++++++++++++++++++++ 3 files changed, 222 insertions(+) create mode 100644 supervision/detection/lmm.py create mode 100644 test/detection/test_lmm.py diff --git a/supervision/detection/core.py b/supervision/detection/core.py index 0ba9e4f42..d6a04efb6 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -7,6 +7,7 @@ import numpy as np from supervision.config import CLASS_NAME_DATA_FIELD, ORIENTED_BOX_COORDINATES +from supervision.detection.lmm import LMM, validate_lmm_and_kwargs, from_paligemma from supervision.detection.utils import ( box_non_max_suppression, calculate_masks_centroids, @@ -805,6 +806,52 @@ def from_paddledet(cls, paddledet_result) -> Detections: class_id=paddledet_result["bbox"][:, 0].astype(int), ) + @classmethod + def from_lmm(cls, lmm: Union[LMM, str], result: str, **kwargs) -> Detections: + """ + Creates a Detections object from the given result string based on the specified + Large Multimodal Model (LMM). + + Args: + lmm (Union[LMM, str]): The type of LMM (Large Multimodal Model) to use. + result (str): The result string containing the detection data. + **kwargs: Additional keyword arguments required by the specified LMM. + + Returns: + Detections: A new Detections object. + + Raises: + ValueError: If the LMM is invalid, required arguments are missing, or + disallowed arguments are provided. + ValueError: If the specified LMM is not supported. + + Examples: + ```python + import supervision as sv + + paligemma_result = " cat" + detections = sv.Detections.from_lmm( + sv.LMM.PALIGEMMA, + paligemma_result, + resolution_wh=(1000, 1000), + classes=['cat', 'dog'] + ) + detections.xyxy + # array([[250., 250., 750., 750.]]) + + detections.class_id + # array([0]) + ``` + """ + lmm = validate_lmm_and_kwargs(lmm, kwargs) + + if lmm == LMM.PALIGEMMA: + xyxy, class_id, class_name = from_paligemma(result, **kwargs) + data = {CLASS_NAME_DATA_FIELD: class_name} + return cls(xyxy=xyxy, class_id=class_id, data=data) + + raise ValueError(f"Unsupported LMM: {lmm}") + @classmethod def empty(cls) -> Detections: """ diff --git a/supervision/detection/lmm.py b/supervision/detection/lmm.py new file mode 100644 index 000000000..1c4b90dc1 --- /dev/null +++ b/supervision/detection/lmm.py @@ -0,0 +1,62 @@ +import re +import numpy as np +from enum import Enum +from typing import Dict, List, Tuple, Optional, Union, Any + + +class LMM(Enum): + PALIGEMMA = 'paligemma' + + +REQUIRED_ARGUMENTS: Dict[LMM, List[str]] = { + LMM.PALIGEMMA: ['resolution_wh'] +} + +ALLOWED_ARGUMENTS: Dict[LMM, List[str]] = { + LMM.PALIGEMMA: ['resolution_wh', 'classes'] +} + + +def validate_lmm_and_kwargs(lmm: Union[LMM, str], kwargs: Dict[str, Any]) -> LMM: + if isinstance(lmm, str): + try: + lmm = LMM(lmm.lower()) + except ValueError: + raise ValueError( + f"Invalid lmm value: {lmm}. Must be one of {[e.value for e in LMM]}" + ) + + required_args = REQUIRED_ARGUMENTS.get(lmm, []) + for arg in required_args: + if arg not in kwargs: + raise ValueError(f"Missing required argument: {arg}") + + allowed_args = ALLOWED_ARGUMENTS.get(lmm, []) + for arg in kwargs: + if arg not in allowed_args: + raise ValueError(f"Argument {arg} is not allowed for {lmm.name}") + + return lmm + + +def from_paligemma( + result: str, + resolution_wh: Tuple[int, int], + classes: Optional[List[str]] = None +) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]: + w, h = resolution_wh + pattern = re.compile( + r'(?) (\w+)') + matches = pattern.findall(result) + matches = np.array(matches) if matches else np.empty((0, 5)) + + xyxy, class_name = matches[:, [1, 0, 3, 2]], matches[:, 4] + xyxy = xyxy.astype(int) / 1024 * np.array([w, h, w, h]) + class_id = None + + if classes is not None: + mask = np.array([name in classes for name in class_name]) + xyxy, class_name = xyxy[mask], class_name[mask] + class_id = np.array([classes.index(name) for name in class_name]) + + return xyxy, class_id, class_name.astype(np.dtype('U')) diff --git a/test/detection/test_lmm.py b/test/detection/test_lmm.py new file mode 100644 index 000000000..b7f7c5b4c --- /dev/null +++ b/test/detection/test_lmm.py @@ -0,0 +1,113 @@ +import numpy as np +from typing import Tuple, Optional, List + +import pytest + +from supervision.detection.lmm import from_paligemma + + +@pytest.mark.parametrize( + "result, resolution_wh, classes, expected_results", + [ + ( + "", + (1000, 1000), + None, + (np.empty((0, 4)), None, np.empty(0).astype(np.dtype('U'))) + ), # empty response + ( + "\n", + (1000, 1000), + None, + (np.empty((0, 4)), None, np.empty(0).astype(np.dtype('U'))) + ), # new line response + ( + "the quick brown fox jumps over the lazy dog.", + (1000, 1000), + None, + (np.empty((0, 4)), None, np.empty(0).astype(np.dtype('U'))) + ), # response with no location + ( + " cat", + (1000, 1000), + None, + (np.empty((0, 4)), None, np.empty(0).astype(np.dtype('U'))) + ), # response with missing location + ( + " cat", + (1000, 1000), + None, + (np.empty((0, 4)), None, np.empty(0).astype(np.dtype('U'))) + ), # response with extra location + ( + "", + (1000, 1000), + None, + (np.empty((0, 4)), None, np.empty(0).astype(np.dtype('U'))) + ), # response with no class + ( + " catt", + (1000, 1000), + ['cat', 'dog'], + (np.empty((0, 4)), np.empty(0), np.empty(0).astype(np.dtype('U'))) + ), # response with invalid class + ( + " cat", + (1000, 1000), + None, + ( + np.array([[250., 250., 750., 750.]]), + None, + np.array(['cat']).astype(np.dtype('U')) + ) + ), # correct response; no classes + ( + " cat ;", + (1000, 1000), + ['cat', 'dog'], + ( + np.array([[250., 250., 750., 750.]]), + np.array([0]), + np.array(['cat']).astype(np.dtype('U')) + ) + ), # correct response; with classes + ( + " cat ; cat", + (1000, 1000), + ['cat', 'dog'], + ( + np.array([[250., 250., 750., 750.]]), + np.array([0]), + np.array(['cat']).astype(np.dtype('U')) + ) + ), # partially correct response; with classes + ( + " cat ; cat", + (1000, 1000), + ['cat', 'dog'], + ( + np.array([[250., 250., 750., 750.]]), + np.array([0]), + np.array(['cat']).astype(np.dtype('U')) + ) + ), # partially correct response; with classes + ] +) +def test_from_paligemma( + result: str, + resolution_wh: Tuple[int, int], + classes: Optional[List[str]], + expected_results: Tuple[np.ndarray, Optional[np.ndarray], np.ndarray] +) -> None: + result = from_paligemma(result=result, resolution_wh=resolution_wh, classes=classes) + + print(result[0].dtype) + print(expected_results[0].dtype) + # print(result[1]) + # print(expected_results[1]) + print(result[2].dtype) + print(expected_results[2].dtype) + + np.testing.assert_array_equal(result[0], expected_results[0]) + np.testing.assert_array_equal(result[1], expected_results[1]) + np.testing.assert_array_equal(result[2], expected_results[2]) From 36a73f7d4912a56ab289d38d22b17b1538bd39ed Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Wed, 22 May 2024 14:26:13 +0200 Subject: [PATCH 2/8] clean up --- test/detection/test_lmm.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/test/detection/test_lmm.py b/test/detection/test_lmm.py index b7f7c5b4c..f8ea91ef9 100644 --- a/test/detection/test_lmm.py +++ b/test/detection/test_lmm.py @@ -100,14 +100,6 @@ def test_from_paligemma( expected_results: Tuple[np.ndarray, Optional[np.ndarray], np.ndarray] ) -> None: result = from_paligemma(result=result, resolution_wh=resolution_wh, classes=classes) - - print(result[0].dtype) - print(expected_results[0].dtype) - # print(result[1]) - # print(expected_results[1]) - print(result[2].dtype) - print(expected_results[2].dtype) - np.testing.assert_array_equal(result[0], expected_results[0]) np.testing.assert_array_equal(result[1], expected_results[1]) np.testing.assert_array_equal(result[2], expected_results[2]) From 7eb918282cf81c0c30e80ddad279265fe35528a9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 May 2024 12:27:53 +0000 Subject: [PATCH 3/8] =?UTF-8?q?fix(pre=5Fcommit):=20=F0=9F=8E=A8=20auto=20?= =?UTF-8?q?format=20pre-commit=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- supervision/detection/core.py | 2 +- supervision/detection/lmm.py | 24 +++++++--------- test/detection/test_lmm.py | 54 +++++++++++++++++------------------ 3 files changed, 38 insertions(+), 42 deletions(-) diff --git a/supervision/detection/core.py b/supervision/detection/core.py index d6a04efb6..e85998173 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -7,7 +7,7 @@ import numpy as np from supervision.config import CLASS_NAME_DATA_FIELD, ORIENTED_BOX_COORDINATES -from supervision.detection.lmm import LMM, validate_lmm_and_kwargs, from_paligemma +from supervision.detection.lmm import LMM, from_paligemma, validate_lmm_and_kwargs from supervision.detection.utils import ( box_non_max_suppression, calculate_masks_centroids, diff --git a/supervision/detection/lmm.py b/supervision/detection/lmm.py index 1c4b90dc1..679213288 100644 --- a/supervision/detection/lmm.py +++ b/supervision/detection/lmm.py @@ -1,20 +1,17 @@ import re -import numpy as np from enum import Enum -from typing import Dict, List, Tuple, Optional, Union, Any +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np class LMM(Enum): - PALIGEMMA = 'paligemma' + PALIGEMMA = "paligemma" -REQUIRED_ARGUMENTS: Dict[LMM, List[str]] = { - LMM.PALIGEMMA: ['resolution_wh'] -} +REQUIRED_ARGUMENTS: Dict[LMM, List[str]] = {LMM.PALIGEMMA: ["resolution_wh"]} -ALLOWED_ARGUMENTS: Dict[LMM, List[str]] = { - LMM.PALIGEMMA: ['resolution_wh', 'classes'] -} +ALLOWED_ARGUMENTS: Dict[LMM, List[str]] = {LMM.PALIGEMMA: ["resolution_wh", "classes"]} def validate_lmm_and_kwargs(lmm: Union[LMM, str], kwargs: Dict[str, Any]) -> LMM: @@ -40,13 +37,12 @@ def validate_lmm_and_kwargs(lmm: Union[LMM, str], kwargs: Dict[str, Any]) -> LMM def from_paligemma( - result: str, - resolution_wh: Tuple[int, int], - classes: Optional[List[str]] = None + result: str, resolution_wh: Tuple[int, int], classes: Optional[List[str]] = None ) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]: w, h = resolution_wh pattern = re.compile( - r'(?) (\w+)') + r"(?) (\w+)" + ) matches = pattern.findall(result) matches = np.array(matches) if matches else np.empty((0, 5)) @@ -59,4 +55,4 @@ def from_paligemma( xyxy, class_name = xyxy[mask], class_name[mask] class_id = np.array([classes.index(name) for name in class_name]) - return xyxy, class_id, class_name.astype(np.dtype('U')) + return xyxy, class_id, class_name.astype(np.dtype("U")) diff --git a/test/detection/test_lmm.py b/test/detection/test_lmm.py index f8ea91ef9..5066a7a3e 100644 --- a/test/detection/test_lmm.py +++ b/test/detection/test_lmm.py @@ -1,6 +1,6 @@ -import numpy as np -from typing import Tuple, Optional, List +from typing import List, Optional, Tuple +import numpy as np import pytest from supervision.detection.lmm import from_paligemma @@ -13,91 +13,91 @@ "", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(np.dtype('U'))) + (np.empty((0, 4)), None, np.empty(0).astype(np.dtype("U"))), ), # empty response ( "\n", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(np.dtype('U'))) + (np.empty((0, 4)), None, np.empty(0).astype(np.dtype("U"))), ), # new line response ( "the quick brown fox jumps over the lazy dog.", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(np.dtype('U'))) + (np.empty((0, 4)), None, np.empty(0).astype(np.dtype("U"))), ), # response with no location ( " cat", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(np.dtype('U'))) + (np.empty((0, 4)), None, np.empty(0).astype(np.dtype("U"))), ), # response with missing location ( " cat", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(np.dtype('U'))) + (np.empty((0, 4)), None, np.empty(0).astype(np.dtype("U"))), ), # response with extra location ( "", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(np.dtype('U'))) + (np.empty((0, 4)), None, np.empty(0).astype(np.dtype("U"))), ), # response with no class ( " catt", (1000, 1000), - ['cat', 'dog'], - (np.empty((0, 4)), np.empty(0), np.empty(0).astype(np.dtype('U'))) + ["cat", "dog"], + (np.empty((0, 4)), np.empty(0), np.empty(0).astype(np.dtype("U"))), ), # response with invalid class ( " cat", (1000, 1000), None, ( - np.array([[250., 250., 750., 750.]]), + np.array([[250.0, 250.0, 750.0, 750.0]]), None, - np.array(['cat']).astype(np.dtype('U')) - ) + np.array(["cat"]).astype(np.dtype("U")), + ), ), # correct response; no classes ( " cat ;", (1000, 1000), - ['cat', 'dog'], + ["cat", "dog"], ( - np.array([[250., 250., 750., 750.]]), + np.array([[250.0, 250.0, 750.0, 750.0]]), np.array([0]), - np.array(['cat']).astype(np.dtype('U')) - ) + np.array(["cat"]).astype(np.dtype("U")), + ), ), # correct response; with classes ( " cat ; cat", (1000, 1000), - ['cat', 'dog'], + ["cat", "dog"], ( - np.array([[250., 250., 750., 750.]]), + np.array([[250.0, 250.0, 750.0, 750.0]]), np.array([0]), - np.array(['cat']).astype(np.dtype('U')) - ) + np.array(["cat"]).astype(np.dtype("U")), + ), ), # partially correct response; with classes ( " cat ; cat", (1000, 1000), - ['cat', 'dog'], + ["cat", "dog"], ( - np.array([[250., 250., 750., 750.]]), + np.array([[250.0, 250.0, 750.0, 750.0]]), np.array([0]), - np.array(['cat']).astype(np.dtype('U')) - ) + np.array(["cat"]).astype(np.dtype("U")), + ), ), # partially correct response; with classes - ] + ], ) def test_from_paligemma( result: str, resolution_wh: Tuple[int, int], classes: Optional[List[str]], - expected_results: Tuple[np.ndarray, Optional[np.ndarray], np.ndarray] + expected_results: Tuple[np.ndarray, Optional[np.ndarray], np.ndarray], ) -> None: result = from_paligemma(result=result, resolution_wh=resolution_wh, classes=classes) np.testing.assert_array_equal(result[0], expected_results[0]) From b81c5e84758487494f90ad67805f9ddba4564ebe Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Wed, 22 May 2024 16:28:24 +0200 Subject: [PATCH 4/8] update to allow multi-word class names --- supervision/detection/lmm.py | 5 ++-- test/detection/test_lmm.py | 45 +++++++++++++++++++++++++++--------- 2 files changed, 37 insertions(+), 13 deletions(-) diff --git a/supervision/detection/lmm.py b/supervision/detection/lmm.py index 1c4b90dc1..9cd7434ec 100644 --- a/supervision/detection/lmm.py +++ b/supervision/detection/lmm.py @@ -46,12 +46,13 @@ def from_paligemma( ) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]: w, h = resolution_wh pattern = re.compile( - r'(?) (\w+)') + r'(?) ([\w\s]+)') matches = pattern.findall(result) matches = np.array(matches) if matches else np.empty((0, 5)) xyxy, class_name = matches[:, [1, 0, 3, 2]], matches[:, 4] xyxy = xyxy.astype(int) / 1024 * np.array([w, h, w, h]) + class_name = np.char.strip(class_name.astype(str)) class_id = None if classes is not None: @@ -59,4 +60,4 @@ def from_paligemma( xyxy, class_name = xyxy[mask], class_name[mask] class_id = np.array([classes.index(name) for name in class_name]) - return xyxy, class_id, class_name.astype(np.dtype('U')) + return xyxy, class_id, class_name diff --git a/test/detection/test_lmm.py b/test/detection/test_lmm.py index f8ea91ef9..91840d887 100644 --- a/test/detection/test_lmm.py +++ b/test/detection/test_lmm.py @@ -13,43 +13,43 @@ "", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(np.dtype('U'))) + (np.empty((0, 4)), None, np.empty(0).astype(str)) ), # empty response ( "\n", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(np.dtype('U'))) + (np.empty((0, 4)), None, np.empty(0).astype(str)) ), # new line response ( "the quick brown fox jumps over the lazy dog.", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(np.dtype('U'))) + (np.empty((0, 4)), None, np.empty(0).astype(str)) ), # response with no location ( " cat", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(np.dtype('U'))) + (np.empty((0, 4)), None, np.empty(0).astype(str)) ), # response with missing location ( " cat", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(np.dtype('U'))) + (np.empty((0, 4)), None, np.empty(0).astype(str)) ), # response with extra location ( "", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(np.dtype('U'))) + (np.empty((0, 4)), None, np.empty(0).astype(str)) ), # response with no class ( " catt", (1000, 1000), ['cat', 'dog'], - (np.empty((0, 4)), np.empty(0), np.empty(0).astype(np.dtype('U'))) + (np.empty((0, 4)), np.empty(0), np.empty(0).astype(str)) ), # response with invalid class ( " cat", @@ -58,7 +58,17 @@ ( np.array([[250., 250., 750., 750.]]), None, - np.array(['cat']).astype(np.dtype('U')) + np.array(['cat']).astype(str) + ) + ), # correct response; no classes + ( + " black cat", + (1000, 1000), + None, + ( + np.array([[250., 250., 750., 750.]]), + None, + np.array(['black cat']).astype(np.dtype('U')) ) ), # correct response; no classes ( @@ -68,7 +78,20 @@ ( np.array([[250., 250., 750., 750.]]), np.array([0]), - np.array(['cat']).astype(np.dtype('U')) + np.array(['cat']).astype(str) + ) + ), # correct response; with classes + ( + " cat ; dog", + (1000, 1000), + ['cat', 'dog'], + ( + np.array([ + [250., 250., 750., 750.], + [250., 250., 750., 750.] + ]), + np.array([0, 1]), + np.array(['cat', 'dog']).astype(np.dtype('U')) ) ), # correct response; with classes ( @@ -78,7 +101,7 @@ ( np.array([[250., 250., 750., 750.]]), np.array([0]), - np.array(['cat']).astype(np.dtype('U')) + np.array(['cat']).astype(str) ) ), # partially correct response; with classes ( @@ -88,7 +111,7 @@ ( np.array([[250., 250., 750., 750.]]), np.array([0]), - np.array(['cat']).astype(np.dtype('U')) + np.array(['cat']).astype(str) ) ), # partially correct response; with classes ] From bf43d6566b8bfa9ac6c317bceb312c414268a60b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 May 2024 14:29:20 +0000 Subject: [PATCH 5/8] =?UTF-8?q?fix(pre=5Fcommit):=20=F0=9F=8E=A8=20auto=20?= =?UTF-8?q?format=20pre-commit=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- supervision/detection/lmm.py | 22 +++++------ test/detection/test_lmm.py | 71 +++++++++++++++++------------------- 2 files changed, 43 insertions(+), 50 deletions(-) diff --git a/supervision/detection/lmm.py b/supervision/detection/lmm.py index 9cd7434ec..3660fb68e 100644 --- a/supervision/detection/lmm.py +++ b/supervision/detection/lmm.py @@ -1,20 +1,17 @@ import re -import numpy as np from enum import Enum -from typing import Dict, List, Tuple, Optional, Union, Any +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np class LMM(Enum): - PALIGEMMA = 'paligemma' + PALIGEMMA = "paligemma" -REQUIRED_ARGUMENTS: Dict[LMM, List[str]] = { - LMM.PALIGEMMA: ['resolution_wh'] -} +REQUIRED_ARGUMENTS: Dict[LMM, List[str]] = {LMM.PALIGEMMA: ["resolution_wh"]} -ALLOWED_ARGUMENTS: Dict[LMM, List[str]] = { - LMM.PALIGEMMA: ['resolution_wh', 'classes'] -} +ALLOWED_ARGUMENTS: Dict[LMM, List[str]] = {LMM.PALIGEMMA: ["resolution_wh", "classes"]} def validate_lmm_and_kwargs(lmm: Union[LMM, str], kwargs: Dict[str, Any]) -> LMM: @@ -40,13 +37,12 @@ def validate_lmm_and_kwargs(lmm: Union[LMM, str], kwargs: Dict[str, Any]) -> LMM def from_paligemma( - result: str, - resolution_wh: Tuple[int, int], - classes: Optional[List[str]] = None + result: str, resolution_wh: Tuple[int, int], classes: Optional[List[str]] = None ) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]: w, h = resolution_wh pattern = re.compile( - r'(?) ([\w\s]+)') + r"(?) ([\w\s]+)" + ) matches = pattern.findall(result) matches = np.array(matches) if matches else np.empty((0, 5)) diff --git a/test/detection/test_lmm.py b/test/detection/test_lmm.py index 91840d887..5b4f31ba2 100644 --- a/test/detection/test_lmm.py +++ b/test/detection/test_lmm.py @@ -1,6 +1,6 @@ -import numpy as np -from typing import Tuple, Optional, List +from typing import List, Optional, Tuple +import numpy as np import pytest from supervision.detection.lmm import from_paligemma @@ -13,114 +13,111 @@ "", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(str)) + (np.empty((0, 4)), None, np.empty(0).astype(str)), ), # empty response ( "\n", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(str)) + (np.empty((0, 4)), None, np.empty(0).astype(str)), ), # new line response ( "the quick brown fox jumps over the lazy dog.", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(str)) + (np.empty((0, 4)), None, np.empty(0).astype(str)), ), # response with no location ( " cat", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(str)) + (np.empty((0, 4)), None, np.empty(0).astype(str)), ), # response with missing location ( " cat", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(str)) + (np.empty((0, 4)), None, np.empty(0).astype(str)), ), # response with extra location ( "", (1000, 1000), None, - (np.empty((0, 4)), None, np.empty(0).astype(str)) + (np.empty((0, 4)), None, np.empty(0).astype(str)), ), # response with no class ( " catt", (1000, 1000), - ['cat', 'dog'], - (np.empty((0, 4)), np.empty(0), np.empty(0).astype(str)) + ["cat", "dog"], + (np.empty((0, 4)), np.empty(0), np.empty(0).astype(str)), ), # response with invalid class ( " cat", (1000, 1000), None, ( - np.array([[250., 250., 750., 750.]]), + np.array([[250.0, 250.0, 750.0, 750.0]]), None, - np.array(['cat']).astype(str) - ) + np.array(["cat"]).astype(str), + ), ), # correct response; no classes ( " black cat", (1000, 1000), None, ( - np.array([[250., 250., 750., 750.]]), + np.array([[250.0, 250.0, 750.0, 750.0]]), None, - np.array(['black cat']).astype(np.dtype('U')) - ) + np.array(["black cat"]).astype(np.dtype("U")), + ), ), # correct response; no classes ( " cat ;", (1000, 1000), - ['cat', 'dog'], + ["cat", "dog"], ( - np.array([[250., 250., 750., 750.]]), + np.array([[250.0, 250.0, 750.0, 750.0]]), np.array([0]), - np.array(['cat']).astype(str) - ) + np.array(["cat"]).astype(str), + ), ), # correct response; with classes ( " cat ; dog", (1000, 1000), - ['cat', 'dog'], + ["cat", "dog"], ( - np.array([ - [250., 250., 750., 750.], - [250., 250., 750., 750.] - ]), + np.array([[250.0, 250.0, 750.0, 750.0], [250.0, 250.0, 750.0, 750.0]]), np.array([0, 1]), - np.array(['cat', 'dog']).astype(np.dtype('U')) - ) + np.array(["cat", "dog"]).astype(np.dtype("U")), + ), ), # correct response; with classes ( " cat ; cat", (1000, 1000), - ['cat', 'dog'], + ["cat", "dog"], ( - np.array([[250., 250., 750., 750.]]), + np.array([[250.0, 250.0, 750.0, 750.0]]), np.array([0]), - np.array(['cat']).astype(str) - ) + np.array(["cat"]).astype(str), + ), ), # partially correct response; with classes ( " cat ; cat", (1000, 1000), - ['cat', 'dog'], + ["cat", "dog"], ( - np.array([[250., 250., 750., 750.]]), + np.array([[250.0, 250.0, 750.0, 750.0]]), np.array([0]), - np.array(['cat']).astype(str) - ) + np.array(["cat"]).astype(str), + ), ), # partially correct response; with classes - ] + ], ) def test_from_paligemma( result: str, resolution_wh: Tuple[int, int], classes: Optional[List[str]], - expected_results: Tuple[np.ndarray, Optional[np.ndarray], np.ndarray] + expected_results: Tuple[np.ndarray, Optional[np.ndarray], np.ndarray], ) -> None: result = from_paligemma(result=result, resolution_wh=resolution_wh, classes=classes) np.testing.assert_array_equal(result[0], expected_results[0]) From 07c36c6223a49bbc496503ec3db34f2b0ac775f8 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Wed, 22 May 2024 17:47:52 +0200 Subject: [PATCH 6/8] make linter happy --- test/detection/test_lmm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/detection/test_lmm.py b/test/detection/test_lmm.py index 5b4f31ba2..b6232bd85 100644 --- a/test/detection/test_lmm.py +++ b/test/detection/test_lmm.py @@ -82,7 +82,7 @@ ), ), # correct response; with classes ( - " cat ; dog", + " cat ; dog", # noqa: E501 (1000, 1000), ["cat", "dog"], ( @@ -92,7 +92,7 @@ ), ), # correct response; with classes ( - " cat ; cat", + " cat ; cat", # noqa: E501 (1000, 1000), ["cat", "dog"], ( @@ -102,7 +102,7 @@ ), ), # partially correct response; with classes ( - " cat ; cat", + " cat ; cat", # noqa: E501 (1000, 1000), ["cat", "dog"], ( From 0115ef8b7d40b9a242ed3799820b897b43b5da7e Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Thu, 23 May 2024 09:11:47 +0200 Subject: [PATCH 7/8] small fix when `mask` is empty --- supervision/detection/lmm.py | 2 +- test/detection/test_lmm.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/supervision/detection/lmm.py b/supervision/detection/lmm.py index 3660fb68e..0278fc004 100644 --- a/supervision/detection/lmm.py +++ b/supervision/detection/lmm.py @@ -52,7 +52,7 @@ def from_paligemma( class_id = None if classes is not None: - mask = np.array([name in classes for name in class_name]) + mask = np.array([name in classes for name in class_name]).astype(bool) xyxy, class_name = xyxy[mask], class_name[mask] class_id = np.array([classes.index(name) for name in class_name]) diff --git a/test/detection/test_lmm.py b/test/detection/test_lmm.py index b6232bd85..e20b947d3 100644 --- a/test/detection/test_lmm.py +++ b/test/detection/test_lmm.py @@ -15,6 +15,12 @@ None, (np.empty((0, 4)), None, np.empty(0).astype(str)), ), # empty response + ( + "", + (1000, 1000), + ['cat', 'dog'], + (np.empty((0, 4)), None, np.empty(0).astype(str)), + ), # empty response with classes ( "\n", (1000, 1000), From ad2220bc1da2e018d1ce08685359eb02ab3c5bd4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 23 May 2024 07:12:17 +0000 Subject: [PATCH 8/8] =?UTF-8?q?fix(pre=5Fcommit):=20=F0=9F=8E=A8=20auto=20?= =?UTF-8?q?format=20pre-commit=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/detection/test_lmm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/detection/test_lmm.py b/test/detection/test_lmm.py index e20b947d3..129aa44b4 100644 --- a/test/detection/test_lmm.py +++ b/test/detection/test_lmm.py @@ -18,7 +18,7 @@ ( "", (1000, 1000), - ['cat', 'dog'], + ["cat", "dog"], (np.empty((0, 4)), None, np.empty(0).astype(str)), ), # empty response with classes (