Image-driven object detection using OWLv2 #78

Merged · 4 commits · Feb 4, 2025
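This PR extends OWLv2Annotator.annotate_batch to accept query images in place of text prompts, enabling image-guided (one-shot) detection. A minimal usage sketch, distilled from the __main__ example in the diff below; the import path is assumed from the file location shown in the diff header:

import requests
from PIL import Image

from datadreamer.dataset_annotation.owlv2_annotator import OWLv2Annotator

annotator = OWLv2Annotator(device="cpu", size="base")

# Text-driven detection: class names as prompts
im = Image.open(requests.get("https://ultralytics.com/images/bus.jpg", stream=True).raw)
boxes, scores, labels = annotator.annotate_batch([im], ["bus", "person"])

# Image-driven detection (new in this PR): a query image instead of text
target = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
query = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000058111.jpg", stream=True).raw
)
boxes, scores, labels = annotator.annotate_batch([target], [query], conf_threshold=0.9)

annotator.release()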
Changes from 1 commit
130 changes: 120 additions & 10 deletions datadreamer/dataset_annotation/owlv2_annotator.py
@@ -3,6 +3,8 @@
import logging
from typing import Dict, List, Tuple

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
@@ -80,43 +82,116 @@ def _init_processor(self) -> Owlv2Processor:
"google/owlv2-base-patch16-ensemble", do_pad=False, do_resize=False
)

def _generate_annotations_from_text(
self,
images: List[PIL.Image.Image],
prompts: List[str],
conf_threshold: float = 0.1,
) -> List[Dict[str, torch.Tensor]]:
"""Generates annotations for the given images and prompts.
"""Generates annotations for the given images and text prompts.

Args:
images: The images to be annotated.
prompts: The text prompts to guide the annotation.
conf_threshold (float, optional): Confidence threshold for the annotations. Defaults to 0.1.

Returns:
List[Dict[str, torch.Tensor]]: The annotations for the given images and text prompts.
"""
batched_prompts = [prompts] * len(images)
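# img.size is (width, height) in PIL; reversed here to (height, width), which post-processing expects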
target_sizes = torch.Tensor([img.size[::-1] for img in images]).to(self.device)

# resize the images to the model's input size
img_size = (1008, 1008) if self.size == "large" else (960, 960)
images = [img.resize(img_size) for img in images]

inputs = self.processor(
images=images,
text=batched_prompts,
return_tensors="pt",
padding="max_length",
).to(self.device)

with torch.no_grad():
outputs = self.model(**inputs)

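# Rescale boxes to the original image sizes and drop detections below conf_threshold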
preds = self.processor.post_process_object_detection(
outputs=outputs, target_sizes=target_sizes, threshold=conf_threshold
)

return preds

def _generate_annotations_from_image(
self,
images: List[PIL.Image.Image],
query_images: List[PIL.Image.Image],
conf_threshold: float = 0.1,
) -> List[Dict[str, torch.Tensor]]:
"""Generates annotations for the given images and query images.

Args:
images: The images to be annotated.
query_images: The query images to guide the annotation. Provide either one query image per target image, or a single query image to apply to all targets.
conf_threshold (float, optional): Confidence threshold for the annotations. Defaults to 0.1.

Returns:
List[Dict[str, torch.Tensor]]: The annotations for the given images and query images.
"""

if len(query_images) != len(images) and len(query_images) != 1:
raise ValueError(
"The number of query images must be either 1 or the same as the number of target images."
)

target_sizes = torch.Tensor([img.size[::-1] for img in images]).to(self.device)

inputs = self.processor(
images=images,
query_images=query_images,
return_tensors="pt",
do_resize=True,
).to(self.device)

with torch.no_grad():
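# OWLv2 image-guided detection embeds the query image and matches visually similar regions in the targets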
outputs = self.model.image_guided_detection(**inputs)

preds = self.processor.post_process_image_guided_detection(
outputs=outputs,
target_sizes=target_sizes,
threshold=conf_threshold,
)

return preds

def _generate_annotations(
self,
images: List[PIL.Image.Image],
prompts: List[str] | List[PIL.Image.Image],
conf_threshold: float = 0.1,
) -> List[Dict[str, torch.Tensor]]:
"""Generates annotations for the given images and prompts.

Args:
images: The images to be annotated.
prompts: Either text prompts (List[str]) or a list of query images (List[PIL.Image.Image]).
conf_threshold (float, optional): Confidence threshold for the annotations. Defaults to 0.1.

Returns:
List[Dict[str, torch.Tensor]]: The annotations for the given images and prompts.
"""
if isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
return self._generate_annotations_from_text(images, prompts, conf_threshold)
elif isinstance(prompts, list) and all(
isinstance(p, PIL.Image.Image) for p in prompts
):
return self._generate_annotations_from_image(
images, prompts, conf_threshold
)
else:
raise ValueError(
"Invalid prompts: Expected List[str] or List[PIL.Image.Image]"
)

def _get_annotations(
self,
pred: Dict[str, torch.Tensor],
@@ -157,6 +232,9 @@ def _get_annotations(
if use_tta:
boxes[:, [0, 2]] = img_width - boxes[:, [2, 0]]

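# Image-guided detection returns boxes and scores but no labels; default every match to class 0 (the query object)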
if labels is None:
labels = torch.zeros(scores.shape, dtype=torch.int64)

return boxes, scores, labels

def _correct_bboxes_misalignment(
@@ -186,7 +264,7 @@ def _correct_bboxes_misalignment(
def annotate_batch(
self,
images: List[PIL.Image.Image],
prompts: List[str] | List[PIL.Image.Image],
conf_threshold: float = 0.1,
iou_threshold: float = 0.2,
use_tta: bool = False,
@@ -324,10 +402,42 @@ def release(self, empty_cuda_cache: bool = False) -> None:
import requests
from PIL import Image

# Text-driven annotation
url = "https://ultralytics.com/images/bus.jpg"
im = Image.open(requests.get(url, stream=True).raw)
annotator = OWLv2Annotator(device="cpu", size="base")
final_boxes, final_scores, final_labels = annotator.annotate_batch(
[im], ["bus", "person"]
)

# Image-driven annotation
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
query_url = "http://images.cocodataset.org/val2017/000000058111.jpg"
query_image = Image.open(requests.get(query_url, stream=True).raw)

final_boxes, final_scores, final_labels = annotator.annotate_batch(
[im], [query_image], conf_threshold=0.9
)
print(final_boxes, final_scores, final_labels)

fig, ax = plt.subplots(1)
ax.imshow(im)
for box, score, label in zip(final_boxes[0], final_scores[0], final_labels[0]):
x1, y1, x2, y2 = box
width, height = x2 - x1, y2 - y1
rect = patches.Rectangle(
(x1, y1), width, height, linewidth=2, edgecolor="r", facecolor="none"
)
ax.add_patch(rect)

plt.text(
x1,
y1,
f"{label} {score:.2f}",
bbox=dict(facecolor="yellow", alpha=0.5),
)

plt.savefig("test_image_guided.png")

annotator.release()
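Note that the image-guided example uses conf_threshold=0.9 against the 0.1 default for text prompts; image-guided matching tends to score far more candidates highly, so a much stricter threshold is usually needed to keep only confident matches.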
49 changes: 39 additions & 10 deletions tests/core_tests/unittests/test_annotators.py
@@ -16,13 +16,26 @@
total_disk_space = psutil.disk_usage("/").total / (1024**3)


def _check_owlv2_annotator(device: str, size: str = "base"):
url = "https://ultralytics.com/images/bus.jpg"
im = Image.open(requests.get(url, stream=True).raw)
def _check_owlv2_annotator(
device: str, size: str = "base", use_text_prompts: bool = True
):
annotator = OWLv2Annotator(device=device, size=size)

if use_text_prompts:
url = "https://ultralytics.com/images/bus.jpg"
im = Image.open(requests.get(url, stream=True).raw)
final_boxes, final_scores, final_labels = annotator.annotate_batch(
[im], ["bus", "people"]
)
else:
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
query_url = "http://images.cocodataset.org/val2017/000000058111.jpg"
query_image = Image.open(requests.get(query_url, stream=True).raw)
final_boxes, final_scores, final_labels = annotator.annotate_batch(
[im], [query_image], conf_threshold=0.9
)
# Assert that boxes, scores and labels come back as per-image lists with a single entry
assert isinstance(final_boxes, list) and len(final_boxes) == 1
assert isinstance(final_scores, list) and len(final_scores) == 1
@@ -45,16 +58,32 @@ def _check_owlv2_annotator(device: str, size: str = "base"):
not torch.cuda.is_available() or total_disk_space < 16,
reason="Test requires GPU and 16GB of HDD",
)
def test_cuda_owlv2_annotator_text():
_check_owlv2_annotator("cuda", use_text_prompts=True)


@pytest.mark.skipif(
total_disk_space < 16,
reason="Test requires at least 16GB of HDD",
)
def test_cpu_owlv2_annotator_text():
_check_owlv2_annotator("cpu", use_text_prompts=True)


@pytest.mark.skipif(
not torch.cuda.is_available() or total_disk_space < 16,
reason="Test requires GPU and 16GB of HDD",
)
def test_cuda_owlv2_annotator_image():
_check_owlv2_annotator("cuda", use_text_prompts=False)


@pytest.mark.skipif(
total_disk_space < 16,
reason="Test requires at least 16GB of HDD",
)
def test_cpu_owlv2_annotator_image():
_check_owlv2_annotator("cpu", use_text_prompts=False)


def _check_aimv2_annotator(device: str):
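To exercise both annotation paths locally, a pytest invocation along these lines should work (pytest is assumed as the runner, per the markers in the test file):

pytest tests/core_tests/unittests/test_annotators.py -k owlv2 -v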