diff --git a/supervision/detection/core.py b/supervision/detection/core.py
index 0ba9e4f42..e85998173 100644
--- a/supervision/detection/core.py
+++ b/supervision/detection/core.py
@@ -7,6 +7,7 @@
 import numpy as np
 
 from supervision.config import CLASS_NAME_DATA_FIELD, ORIENTED_BOX_COORDINATES
+from supervision.detection.lmm import LMM, from_paligemma, validate_lmm_and_kwargs
 from supervision.detection.utils import (
     box_non_max_suppression,
     calculate_masks_centroids,
@@ -805,6 +806,52 @@ def from_paddledet(cls, paddledet_result) -> Detections:
             class_id=paddledet_result["bbox"][:, 0].astype(int),
         )
 
+    @classmethod
+    def from_lmm(cls, lmm: Union[LMM, str], result: str, **kwargs) -> Detections:
+        """
+        Creates a Detections object from the given result string based on the specified
+        Large Multimodal Model (LMM).
+
+        Args:
+            lmm (Union[LMM, str]): The type of LMM (Large Multimodal Model) to use.
+            result (str): The result string containing the detection data.
+            **kwargs: Additional keyword arguments required by the specified LMM.
+
+        Returns:
+            Detections: A new Detections object.
+
+        Raises:
+            ValueError: If the LMM is invalid, required arguments are missing, or
+                disallowed arguments are provided.
+            ValueError: If the specified LMM is not supported.
+
+        Examples:
+            ```python
+            import supervision as sv
+
+            paligemma_result = "<loc0256><loc0256><loc0768><loc0768> cat"
+            detections = sv.Detections.from_lmm(
+                sv.LMM.PALIGEMMA,
+                paligemma_result,
+                resolution_wh=(1000, 1000),
+                classes=['cat', 'dog']
+            )
+            detections.xyxy
+            # array([[250., 250., 750., 750.]])
+
+            detections.class_id
+            # array([0])
+            ```
+        """
+        lmm = validate_lmm_and_kwargs(lmm, kwargs)
+
+        if lmm == LMM.PALIGEMMA:
+            xyxy, class_id, class_name = from_paligemma(result, **kwargs)
+            data = {CLASS_NAME_DATA_FIELD: class_name}
+            return cls(xyxy=xyxy, class_id=class_id, data=data)
+
+        raise ValueError(f"Unsupported LMM: {lmm}")
+
     @classmethod
     def empty(cls) -> Detections:
         """
diff --git a/supervision/detection/lmm.py b/supervision/detection/lmm.py
new file mode 100644
index 000000000..0278fc004
--- /dev/null
+++ b/supervision/detection/lmm.py
@@ -0,0 +1,94 @@
+import re
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+
+class LMM(Enum):
+    """Large Multimodal Models (LMMs) whose text results can be parsed."""
+
+    PALIGEMMA = "paligemma"
+
+
+REQUIRED_ARGUMENTS: Dict[LMM, List[str]] = {LMM.PALIGEMMA: ["resolution_wh"]}
+
+ALLOWED_ARGUMENTS: Dict[LMM, List[str]] = {LMM.PALIGEMMA: ["resolution_wh", "classes"]}
+
+
+def validate_lmm_and_kwargs(lmm: Union[LMM, str], kwargs: Dict[str, Any]) -> LMM:
+    """
+    Resolve `lmm` to an `LMM` member and validate `kwargs` for that model.
+
+    Args:
+        lmm (Union[LMM, str]): An `LMM` member, or its string value (matched
+            case-insensitively).
+        kwargs (Dict[str, Any]): Keyword arguments intended for the model parser.
+
+    Returns:
+        LMM: The resolved `LMM` member.
+
+    Raises:
+        ValueError: If `lmm` is a string that matches no `LMM` value, if a required
+            argument is missing, or if an argument is not allowed for the model.
+    """
+    if isinstance(lmm, str):
+        try:
+            lmm = LMM(lmm.lower())
+        except ValueError:
+            raise ValueError(
+                f"Invalid lmm value: {lmm}. Must be one of {[e.value for e in LMM]}"
+            ) from None
+
+    required_args = REQUIRED_ARGUMENTS.get(lmm, [])
+    for arg in required_args:
+        if arg not in kwargs:
+            raise ValueError(f"Missing required argument: {arg}")
+
+    allowed_args = ALLOWED_ARGUMENTS.get(lmm, [])
+    for arg in kwargs:
+        if arg not in allowed_args:
+            raise ValueError(f"Argument {arg} is not allowed for {lmm.name}")
+
+    return lmm
+
+
+def from_paligemma(
+    result: str, resolution_wh: Tuple[int, int], classes: Optional[List[str]] = None
+) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]:
+    """
+    Parse a PaliGemma detection string into boxes, class ids and class names.
+
+    Each detection is four `<locNNNN>` tokens followed by a space and a label. The
+    token values are divided by 1024 and scaled to `resolution_wh`; the first and
+    third tokens are treated as `y`, the second and fourth as `x`.
+
+    Args:
+        result (str): Raw text produced by the model.
+        resolution_wh (Tuple[int, int]): Target `(width, height)` used to scale the
+            boxes.
+        classes (Optional[List[str]]): If given, detections whose label is not in
+            this list are dropped and `class_id` is the label's index in the list.
+
+    Returns:
+        Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]: `xyxy` boxes,
+            `class_id` (`None` when `classes` is `None`), and the stripped labels.
+    """
+    w, h = resolution_wh
+    pattern = re.compile(
+        r"(?<!<loc\d{4}>)<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})> ([\w\s]+)"
+    )
+    matches = pattern.findall(result)
+    matches = np.array(matches) if matches else np.empty((0, 5))
+
+    xyxy, class_name = matches[:, [1, 0, 3, 2]], matches[:, 4]
+    xyxy = xyxy.astype(int) / 1024 * np.array([w, h, w, h])
+    class_name = np.char.strip(class_name.astype(str))
+    class_id = None
+
+    if classes is not None:
+        mask = np.array([name in classes for name in class_name]).astype(bool)
+        xyxy, class_name = xyxy[mask], class_name[mask]
+        class_id = np.array([classes.index(name) for name in class_name])
+
+    return xyxy, class_id, class_name
diff --git a/test/detection/test_lmm.py b/test/detection/test_lmm.py
new file mode 100644
index 000000000..129aa44b4
--- /dev/null
+++ b/test/detection/test_lmm.py
@@ -0,0 +1,131 @@
+from typing import List, Optional, Tuple
+
+import numpy as np
+import pytest
+
+from supervision.detection.lmm import from_paligemma
+
+
+@pytest.mark.parametrize(
+    "result, resolution_wh, classes, expected_results",
+    [
+        (
+            "",
+            (1000, 1000),
+            None,
+            (np.empty((0, 4)), None, np.empty(0).astype(str)),
+        ),  # empty response
+        (
+            "",
+            (1000, 1000),
+            ["cat", "dog"],
+            (np.empty((0, 4)), np.empty(0), np.empty(0).astype(str)),
+        ),  # empty response with classes
+        (
+            "\n",
+            (1000, 1000),
+            None,
+            (np.empty((0, 4)), None, np.empty(0).astype(str)),
+        ),  # new line response
+        (
+            "the quick brown fox jumps over the lazy dog.",
+            (1000, 1000),
+            None,
+            (np.empty((0, 4)), None, np.empty(0).astype(str)),
+        ),  # response with no location
+        (
+            "<loc0256><loc0768><loc0768> cat",
+            (1000, 1000),
+            None,
+            (np.empty((0, 4)), None, np.empty(0).astype(str)),
+        ),  # response with missing location
+        (
+            "<loc0256><loc0256><loc0768><loc0768><loc0768> cat",
+            (1000, 1000),
+            None,
+            (np.empty((0, 4)), None, np.empty(0).astype(str)),
+        ),  # response with extra location
+        (
+            "<loc0256><loc0256><loc0768><loc0768>",
+            (1000, 1000),
+            None,
+            (np.empty((0, 4)), None, np.empty(0).astype(str)),
+        ),  # response with no class
+        (
+            "<loc0256><loc0256><loc0768><loc0768> catt",
+            (1000, 1000),
+            ["cat", "dog"],
+            (np.empty((0, 4)), np.empty(0), np.empty(0).astype(str)),
+        ),  # response with invalid class
+        (
+            "<loc0256><loc0256><loc0768><loc0768> cat",
+            (1000, 1000),
+            None,
+            (
+                np.array([[250.0, 250.0, 750.0, 750.0]]),
+                None,
+                np.array(["cat"]).astype(str),
+            ),
+        ),  # correct response; no classes
+        (
+            "<loc0256><loc0256><loc0768><loc0768> black cat",
+            (1000, 1000),
+            None,
+            (
+                np.array([[250.0, 250.0, 750.0, 750.0]]),
+                None,
+                np.array(["black cat"]).astype(np.dtype("U")),
+            ),
+        ),  # correct response; no classes
+        (
+            "<loc0256><loc0256><loc0768><loc0768> cat ;",
+            (1000, 1000),
+            ["cat", "dog"],
+            (
+                np.array([[250.0, 250.0, 750.0, 750.0]]),
+                np.array([0]),
+                np.array(["cat"]).astype(str),
+            ),
+        ),  # correct response; with classes
+        (
+            "<loc0256><loc0256><loc0768><loc0768> cat ; <loc0256><loc0256><loc0768><loc0768> dog",  # noqa: E501
+            (1000, 1000),
+            ["cat", "dog"],
+            (
+                np.array([[250.0, 250.0, 750.0, 750.0], [250.0, 250.0, 750.0, 750.0]]),
+                np.array([0, 1]),
+                np.array(["cat", "dog"]).astype(np.dtype("U")),
+            ),
+        ),  # correct response; with classes
+        (
+            "<loc0256><loc0256><loc0768><loc0768> cat ; <loc0256><loc0256><loc0768> cat",  # noqa: E501
+            (1000, 1000),
+            ["cat", "dog"],
+            (
+                np.array([[250.0, 250.0, 750.0, 750.0]]),
+                np.array([0]),
+                np.array(["cat"]).astype(str),
+            ),
+        ),  # partially correct response; with classes
+        (
+            "<loc0256><loc0256><loc0768><loc0768> cat ; <loc0256><loc0256><loc0768><loc0768><loc0768> cat",  # noqa: E501
+            (1000, 1000),
+            ["cat", "dog"],
+            (
+                np.array([[250.0, 250.0, 750.0, 750.0]]),
+                np.array([0]),
+                np.array(["cat"]).astype(str),
+            ),
+        ),  # partially correct response; with classes
+    ],
+)
+def test_from_paligemma(
+    result: str,
+    resolution_wh: Tuple[int, int],
+    classes: Optional[List[str]],
+    expected_results: Tuple[np.ndarray, Optional[np.ndarray], np.ndarray],
+) -> None:
+    result = from_paligemma(result=result, resolution_wh=resolution_wh, classes=classes)
+    np.testing.assert_array_equal(result[0], expected_results[0])
+    np.testing.assert_array_equal(result[1], expected_results[1])
+    np.testing.assert_array_equal(result[2], expected_results[2])