Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial commit adding support for from_lmm and specifically for Pal… #1221

Merged
merged 9 commits into from
May 24, 2024
47 changes: 47 additions & 0 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np

from supervision.config import CLASS_NAME_DATA_FIELD, ORIENTED_BOX_COORDINATES
from supervision.detection.lmm import LMM, from_paligemma, validate_lmm_and_kwargs
from supervision.detection.utils import (
box_non_max_suppression,
calculate_masks_centroids,
Expand Down Expand Up @@ -805,6 +806,52 @@ def from_paddledet(cls, paddledet_result) -> Detections:
class_id=paddledet_result["bbox"][:, 0].astype(int),
)

@classmethod
def from_lmm(cls, lmm: Union[LMM, str], result: str, **kwargs) -> Detections:
"""
Creates a Detections object from the given result string based on the specified
Large Multimodal Model (LMM).

Args:
lmm (Union[LMM, str]): The type of LMM (Large Multimodal Model) to use.
result (str): The result string containing the detection data.
**kwargs: Additional keyword arguments required by the specified LMM.

Returns:
Detections: A new Detections object.

Raises:
ValueError: If the LMM is invalid, required arguments are missing, or
disallowed arguments are provided.
ValueError: If the specified LMM is not supported.

Examples:
```python
import supervision as sv

paligemma_result = "<loc0256><loc0256><loc0768><loc0768> cat"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this would be also useful for models like KOSMOS-2 so best to make it very general (this is a trend with VLMs these days) https://huggingface.co/docs/transformers/en/model_doc/kosmos-2#transformers.Kosmos2ForConditionalGeneration.forward.example

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So what I want to do is one from_lmm function, providing separate dedicated parsers for each model. :)

detections = sv.Detections.from_lmm(
sv.LMM.PALIGEMMA,
paligemma_result,
resolution_wh=(1000, 1000),
classes=['cat', 'dog']
)
detections.xyxy
# array([[250., 250., 750., 750.]])

detections.class_id
# array([0])
```
"""
lmm = validate_lmm_and_kwargs(lmm, kwargs)

if lmm == LMM.PALIGEMMA:
xyxy, class_id, class_name = from_paligemma(result, **kwargs)
data = {CLASS_NAME_DATA_FIELD: class_name}
return cls(xyxy=xyxy, class_id=class_id, data=data)

raise ValueError(f"Unsupported LMM: {lmm}")

@classmethod
def empty(cls) -> Detections:
"""
Expand Down
59 changes: 59 additions & 0 deletions supervision/detection/lmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import re
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np


class LMM(Enum):
PALIGEMMA = "paligemma"


REQUIRED_ARGUMENTS: Dict[LMM, List[str]] = {LMM.PALIGEMMA: ["resolution_wh"]}

ALLOWED_ARGUMENTS: Dict[LMM, List[str]] = {LMM.PALIGEMMA: ["resolution_wh", "classes"]}


def validate_lmm_and_kwargs(lmm: Union[LMM, str], kwargs: Dict[str, Any]) -> LMM:
if isinstance(lmm, str):
try:
lmm = LMM(lmm.lower())
except ValueError:
raise ValueError(
f"Invalid lmm value: {lmm}. Must be one of {[e.value for e in LMM]}"
)

required_args = REQUIRED_ARGUMENTS.get(lmm, [])
for arg in required_args:
if arg not in kwargs:
raise ValueError(f"Missing required argument: {arg}")

allowed_args = ALLOWED_ARGUMENTS.get(lmm, [])
for arg in kwargs:
if arg not in allowed_args:
raise ValueError(f"Argument {arg} is not allowed for {lmm.name}")

return lmm


def from_paligemma(
result: str, resolution_wh: Tuple[int, int], classes: Optional[List[str]] = None
) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]:
w, h = resolution_wh
pattern = re.compile(
r"(?<!<loc\d{4}>)<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})> ([\w\s]+)"
)
matches = pattern.findall(result)
matches = np.array(matches) if matches else np.empty((0, 5))

xyxy, class_name = matches[:, [1, 0, 3, 2]], matches[:, 4]
xyxy = xyxy.astype(int) / 1024 * np.array([w, h, w, h])
class_name = np.char.strip(class_name.astype(str))
class_id = None

if classes is not None:
mask = np.array([name in classes for name in class_name]).astype(bool)
xyxy, class_name = xyxy[mask], class_name[mask]
class_id = np.array([classes.index(name) for name in class_name])

return xyxy, class_id, class_name
131 changes: 131 additions & 0 deletions test/detection/test_lmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from typing import List, Optional, Tuple

import numpy as np
import pytest

from supervision.detection.lmm import from_paligemma


@pytest.mark.parametrize(
"result, resolution_wh, classes, expected_results",
[
(
"",
(1000, 1000),
None,
(np.empty((0, 4)), None, np.empty(0).astype(str)),
), # empty response
(
"",
(1000, 1000),
["cat", "dog"],
(np.empty((0, 4)), None, np.empty(0).astype(str)),
), # empty response with classes
(
"\n",
(1000, 1000),
None,
(np.empty((0, 4)), None, np.empty(0).astype(str)),
), # new line response
(
"the quick brown fox jumps over the lazy dog.",
(1000, 1000),
None,
(np.empty((0, 4)), None, np.empty(0).astype(str)),
), # response with no location
(
"<loc0256><loc0768><loc0768> cat",
(1000, 1000),
None,
(np.empty((0, 4)), None, np.empty(0).astype(str)),
), # response with missing location
(
"<loc0256><loc0256><loc0768><loc0768><loc0768> cat",
(1000, 1000),
None,
(np.empty((0, 4)), None, np.empty(0).astype(str)),
), # response with extra location
(
"<loc0256><loc0256><loc0768><loc0768>",
(1000, 1000),
None,
(np.empty((0, 4)), None, np.empty(0).astype(str)),
), # response with no class
(
"<loc0256><loc0256><loc0768><loc0768> catt",
(1000, 1000),
["cat", "dog"],
(np.empty((0, 4)), np.empty(0), np.empty(0).astype(str)),
), # response with invalid class
(
"<loc0256><loc0256><loc0768><loc0768> cat",
(1000, 1000),
None,
(
np.array([[250.0, 250.0, 750.0, 750.0]]),
None,
np.array(["cat"]).astype(str),
),
), # correct response; no classes
(
"<loc0256><loc0256><loc0768><loc0768> black cat",
(1000, 1000),
None,
(
np.array([[250.0, 250.0, 750.0, 750.0]]),
None,
np.array(["black cat"]).astype(np.dtype("U")),
),
), # correct response; no classes
(
"<loc0256><loc0256><loc0768><loc0768> cat ;",
(1000, 1000),
["cat", "dog"],
(
np.array([[250.0, 250.0, 750.0, 750.0]]),
np.array([0]),
np.array(["cat"]).astype(str),
),
), # correct response; with classes
(
"<loc0256><loc0256><loc0768><loc0768> cat ; <loc0256><loc0256><loc0768><loc0768> dog", # noqa: E501
(1000, 1000),
["cat", "dog"],
(
np.array([[250.0, 250.0, 750.0, 750.0], [250.0, 250.0, 750.0, 750.0]]),
np.array([0, 1]),
np.array(["cat", "dog"]).astype(np.dtype("U")),
),
), # correct response; with classes
(
"<loc0256><loc0256><loc0768><loc0768> cat ; <loc0256><loc0256><loc0768> cat", # noqa: E501
(1000, 1000),
["cat", "dog"],
(
np.array([[250.0, 250.0, 750.0, 750.0]]),
np.array([0]),
np.array(["cat"]).astype(str),
),
), # partially correct response; with classes
(
"<loc0256><loc0256><loc0768><loc0768> cat ; <loc0256><loc0256><loc0768><loc0768><loc0768> cat", # noqa: E501
(1000, 1000),
["cat", "dog"],
(
np.array([[250.0, 250.0, 750.0, 750.0]]),
np.array([0]),
np.array(["cat"]).astype(str),
),
), # partially correct response; with classes
],
)
def test_from_paligemma(
result: str,
resolution_wh: Tuple[int, int],
classes: Optional[List[str]],
expected_results: Tuple[np.ndarray, Optional[np.ndarray], np.ndarray],
) -> None:
result = from_paligemma(result=result, resolution_wh=resolution_wh, classes=classes)
np.testing.assert_array_equal(result[0], expected_results[0])
np.testing.assert_array_equal(result[1], expected_results[1])
np.testing.assert_array_equal(result[2], expected_results[2])