diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 6f964ac..e569db0 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -21,6 +21,8 @@ jobs:
         version: '3.11'

     runs-on: ${{ matrix.os }}
+    env:
+      CUDA_VISIBLE_DEVICES: ""

     steps:
       - name: Checkout
diff --git a/.github/workflows/unit-tests.yaml b/.github/workflows/unit-tests.yaml
index 59de92a..fdbb8bf 100644
--- a/.github/workflows/unit-tests.yaml
+++ b/.github/workflows/unit-tests.yaml
@@ -17,6 +17,8 @@ jobs:
         version: ['3.10', '3.11']

     runs-on: ${{ matrix.os }}
+    env:
+      CUDA_VISIBLE_DEVICES: ""

     steps:
       - name: Checkout
diff --git a/datadreamer/dataset_annotation/owlv2_annotator.py b/datadreamer/dataset_annotation/owlv2_annotator.py
index 9da41b4..1e22dea 100644
--- a/datadreamer/dataset_annotation/owlv2_annotator.py
+++ b/datadreamer/dataset_annotation/owlv2_annotator.py
@@ -3,6 +3,8 @@
 import logging
 from typing import Dict, List, Tuple

+import matplotlib.patches as patches
+import matplotlib.pyplot as plt
 import numpy as np
 import PIL
 import torch
@@ -80,43 +82,116 @@ def _init_processor(self) -> Owlv2Processor:
             "google/owlv2-base-patch16-ensemble", do_pad=False, do_resize=False
         )

-    def _generate_annotations(
+    def _generate_annotations_from_text(
         self,
         images: List[PIL.Image.Image],
         prompts: List[str],
         conf_threshold: float = 0.1,
     ) -> List[Dict[str, torch.Tensor]]:
-        """Generates annotations for the given images and prompts.
+        """Generates annotations for the given images and text prompts.

         Args:
             images: The images to be annotated.
-            prompts: Prompts to guide the annotation.
+            prompts: The text prompts to guide the annotation.
             conf_threshold (float, optional): Confidence threshold for the annotations. Defaults to 0.1.

         Returns:
-            List[Dict[str, torch.Tensor]]: The annotations for the given images and prompts.
+            List[Dict[str, torch.Tensor]]: The annotations for the given images and text prompts.
         """
-        n = len(images)
-        batched_prompts = [prompts] * n
+
+        batched_prompts = [prompts] * len(images)
         target_sizes = torch.Tensor([img.size[::-1] for img in images]).to(self.device)
-        # resize the images to the model's input size
         img_size = (1008, 1008) if self.size == "large" else (960, 960)
-        images = [images[i].resize(img_size) for i in range(n)]
+        images = [img.resize(img_size) for img in images]
+
         inputs = self.processor(
-            text=batched_prompts,
             images=images,
+            text=batched_prompts,
             return_tensors="pt",
             padding="max_length",
         ).to(self.device)
+
         with torch.no_grad():
             outputs = self.model(**inputs)
+
         preds = self.processor.post_process_object_detection(
             outputs=outputs, target_sizes=target_sizes, threshold=conf_threshold
         )

         return preds

+    def _generate_annotations_from_image(
+        self,
+        images: List[PIL.Image.Image],
+        query_images: List[PIL.Image.Image],
+        conf_threshold: float = 0.1,
+    ) -> List[Dict[str, torch.Tensor]]:
+        """Generates annotations for the given images and query images.
+
+        Args:
+            images: The images to be annotated.
+            query_images: The query images to guide the annotation. One query image is expected per target image to be queried.
+            conf_threshold (float, optional): Confidence threshold for the annotations. Defaults to 0.1.
+
+        Returns:
+            List[Dict[str, torch.Tensor]]: The annotations for the given images and query images.
+        """
+
+        if len(query_images) != len(images) and len(query_images) != 1:
+            raise ValueError(
+                "The number of query images must be either 1 or the same as the number of target images."
+            )
+
+        target_sizes = torch.Tensor([img.size[::-1] for img in images]).to(self.device)
+
+        inputs = self.processor(
+            images=images,
+            query_images=query_images,
+            return_tensors="pt",
+            do_resize=True,
+        ).to(self.device)
+
+        with torch.no_grad():
+            outputs = self.model.image_guided_detection(**inputs)
+
+        preds = self.processor.post_process_image_guided_detection(
+            outputs=outputs,
+            target_sizes=target_sizes,
+            threshold=conf_threshold,
+        )
+
+        return preds
+
+    def _generate_annotations(
+        self,
+        images: List[PIL.Image.Image],
+        prompts: List[str] | List[PIL.Image.Image],
+        conf_threshold: float = 0.1,
+    ) -> List[Dict[str, torch.Tensor]]:
+        """Generates annotations for the given images and prompts.
+
+        Args:
+            images: The images to be annotated.
+            prompts: Either text prompts (List[str]) or a list of query images (List[PIL.Image.Image]).
+            conf_threshold (float, optional): Confidence threshold for the annotations. Defaults to 0.1.
+
+        Returns:
+            List[Dict[str, torch.Tensor]]: The annotations for the given images and prompts.
+        """
+        if isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
+            return self._generate_annotations_from_text(images, prompts, conf_threshold)
+        elif isinstance(prompts, list) and all(
+            isinstance(p, PIL.Image.Image) for p in prompts
+        ):
+            return self._generate_annotations_from_image(
+                images, prompts, conf_threshold
+            )
+        else:
+            raise ValueError(
+                "Invalid prompts: Expected List[str] or List[PIL.Image.Image]"
+            )
+
     def _get_annotations(
         self,
         pred: Dict[str, torch.Tensor],
@@ -157,6 +232,9 @@ def _get_annotations(
         if use_tta:
             boxes[:, [0, 2]] = img_width - boxes[:, [2, 0]]

+        if labels is None:
+            labels = torch.zeros(scores.shape, dtype=torch.int64)
+
         return boxes, scores, labels

     def _correct_bboxes_misalignment(
@@ -186,7 +264,7 @@ def _correct_bboxes_misalignment(
     def annotate_batch(
         self,
         images: List[PIL.Image.Image],
-        prompts: List[str],
+        prompts: List[str] | List[PIL.Image.Image],
         conf_threshold: float = 0.1,
         iou_threshold: float = 0.2,
         use_tta: bool = False,
@@ -324,10 +402,42 @@ def release(self, empty_cuda_cache: bool = False) -> None:
     import requests
     from PIL import Image

+    # Text-driven annotation
     url = "https://ultralytics.com/images/bus.jpg"
     im = Image.open(requests.get(url, stream=True).raw)
     annotator = OWLv2Annotator(device="cpu", size="base")
     final_boxes, final_scores, final_labels = annotator.annotate_batch(
         [im], ["bus", "person"]
     )
+
+    # Image-driven annotation
+    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+    im = Image.open(requests.get(url, stream=True).raw)
+    query_url = "http://images.cocodataset.org/val2017/000000058111.jpg"
+    query_image = Image.open(requests.get(query_url, stream=True).raw)
+
+    final_boxes, final_scores, final_labels = annotator.annotate_batch(
+        [im], [query_image], conf_threshold=0.9
+    )
+    print(final_boxes, final_scores, final_labels)
+
+    fig, ax = plt.subplots(1)
+    ax.imshow(im)
+    for box, score, label in zip(final_boxes[0], final_scores[0], final_labels[0]):
+        x1, y1, x2, y2 = box
+        width, height = x2 - x1, y2 - y1
+        rect = patches.Rectangle(
+            (x1, y1), width, height, linewidth=2, edgecolor="r", facecolor="none"
+        )
+        ax.add_patch(rect)
+
+        plt.text(
+            x1,
+            y1,
+            f"{label} {score:.2f}",
+            bbox=dict(facecolor="yellow", alpha=0.5),
+        )
+
+    plt.savefig("test_image_guided.png")
+
     annotator.release()
diff --git a/requirements.txt b/requirements.txt
index cdd8a50..fdaaac7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-torch>=2.0.0
+torch>=2.0.0,<=2.5.1
 torchvision>=0.16.0
 transformers>=4.45.2
 diffusers>=0.31.0
diff --git a/tests/core_tests/unittests/test_annotators.py b/tests/core_tests/unittests/test_annotators.py
index eb5c986..8ea1502 100644
--- a/tests/core_tests/unittests/test_annotators.py
+++ b/tests/core_tests/unittests/test_annotators.py
@@ -16,13 +16,26 @@
 total_disk_space = psutil.disk_usage("/").total / (1024**3)


-def _check_owlv2_annotator(device: str, size: str = "base"):
-    url = "https://ultralytics.com/images/bus.jpg"
-    im = Image.open(requests.get(url, stream=True).raw)
+def _check_owlv2_annotator(
+    device: str, size: str = "base", use_text_prompts: bool = True
+):
     annotator = OWLv2Annotator(device=device, size=size)
-    final_boxes, final_scores, final_labels = annotator.annotate_batch(
-        [im], ["bus", "people"]
-    )
+
+    if use_text_prompts:
+        url = "https://ultralytics.com/images/bus.jpg"
+        im = Image.open(requests.get(url, stream=True).raw)
+        final_boxes, final_scores, final_labels = annotator.annotate_batch(
+            [im], ["bus", "people"]
+        )
+    else:
+        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        im = Image.open(requests.get(url, stream=True).raw)
+        query_url = "http://images.cocodataset.org/val2017/000000058111.jpg"
+        query_image = Image.open(requests.get(query_url, stream=True).raw)
+        annotator = OWLv2Annotator(device=device, size=size)
+        final_boxes, final_scores, final_labels = annotator.annotate_batch(
+            [im], [query_image], conf_threshold=0.9
+        )
     # Assert that the boxes, scores and labels are tensors
     assert isinstance(final_boxes, list) and len(final_boxes) == 1
     assert isinstance(final_scores, list) and len(final_scores) == 1
@@ -45,16 +58,32 @@ def _check_owlv2_annotator(device: str, size: str = "base"):
     not torch.cuda.is_available() or total_disk_space < 16,
     reason="Test requires GPU and 16GB of HDD",
 )
-def test_cuda_owlv2_annotator():
-    _check_owlv2_annotator("cuda")
+def test_cuda_owlv2_annotator_text():
+    _check_owlv2_annotator("cuda", use_text_prompts=True)
+
+
+@pytest.mark.skipif(
+    total_disk_space < 16,
+    reason="Test requires at least 16GB of HDD",
+)
+def test_cpu_owlv2_annotator_text():
+    _check_owlv2_annotator("cpu", use_text_prompts=True)
+
+
+@pytest.mark.skipif(
+    not torch.cuda.is_available() or total_disk_space < 16,
+    reason="Test requires GPU and 16GB of HDD",
+)
+def test_cuda_owlv2_annotator_image():
+    _check_owlv2_annotator("cuda", use_text_prompts=False)


 @pytest.mark.skipif(
     total_disk_space < 16,
     reason="Test requires at least 16GB of HDD",
 )
-def test_cpu_owlv2_annotator():
-    _check_owlv2_annotator("cpu")
+def test_cpu_owlv2_annotator_image():
+    _check_owlv2_annotator("cpu", use_text_prompts=False)


 def _check_aimv2_annotator(device: str):
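
Below is a minimal usage sketch of the image-guided path added above, not part of the diff itself. It assumes `OWLv2Annotator` is importable from `datadreamer.dataset_annotation.owlv2_annotator`; since `post_process_image_guided_detection` returns no class labels, `_get_annotations` fills them with zeros, so every returned box carries class index 0 and the caller decides what that index should be called.

```python
# Minimal sketch (not part of the PR) of calling the new image-guided annotation path.
# Assumption: OWLv2Annotator is importable from datadreamer.dataset_annotation.owlv2_annotator.
import requests
from PIL import Image

from datadreamer.dataset_annotation.owlv2_annotator import OWLv2Annotator

annotator = OWLv2Annotator(device="cpu", size="base")

# Target image to annotate and a query image containing the object to look for.
image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
query = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000058111.jpg", stream=True).raw
)

# Passing a list of PIL images as prompts routes annotate_batch through
# _generate_annotations_from_image instead of the text-prompt path.
boxes, scores, labels = annotator.annotate_batch([image], [query], conf_threshold=0.9)

# Image-guided detection yields no class labels, so all boxes come back with index 0;
# the mapping below is a caller-chosen placeholder name for the query object.
class_names = {0: "query-object"}
for box, score, label in zip(boxes[0], scores[0], labels[0]):
    print(class_names[int(label)], round(float(score), 3), [round(float(v), 1) for v in box])

annotator.release()
```

To mirror the CI configuration locally, the same `CUDA_VISIBLE_DEVICES=""` environment variable set in both workflows can be exported before running this sketch or the new `test_*_owlv2_annotator_image` tests, which forces the CPU code path even on a GPU machine.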