From a280cdd793c63d17e33812b6845ba3b18b8b5890 Mon Sep 17 00:00:00 2001 From: Alara Dirik <8944735+alaradirik@users.noreply.github.com> Date: Tue, 24 Jan 2023 18:49:29 +0300 Subject: [PATCH] Fix MaskFormerImageProcessor.post_process_instance_segmentation (#21256) * fix instance segmentation post processing * add Mask2FormerImageProcessor --- docs/source/en/model_doc/mask2former.mdx | 13 +- src/transformers/__init__.py | 2 + .../models/auto/image_processing_auto.py | 2 +- .../models/mask2former/__init__.py | 15 + ..._original_pytorch_checkpoint_to_pytorch.py | 8 +- .../image_processing_mask2former.py | 1149 +++++++++++++++++ .../mask2former/modeling_mask2former.py | 9 +- .../maskformer/image_processing_maskformer.py | 99 +- .../utils/dummy_vision_objects.py | 7 + .../test_image_processing_mask2former.py | 611 +++++++++ .../mask2former/test_modeling_mask2former.py | 4 +- .../test_image_processing_maskformer.py | 28 + 12 files changed, 1899 insertions(+), 48 deletions(-) create mode 100644 src/transformers/models/mask2former/image_processing_mask2former.py create mode 100644 tests/models/mask2former/test_image_processing_mask2former.py diff --git a/docs/source/en/model_doc/mask2former.mdx b/docs/source/en/model_doc/mask2former.mdx index 71412315e14e69..dd4a5d52332cb3 100644 --- a/docs/source/en/model_doc/mask2former.mdx +++ b/docs/source/en/model_doc/mask2former.mdx @@ -22,8 +22,8 @@ The abstract from the paper is the following: of semantics defines a task. While only the semantics of each task differ, current research focuses on designing specialized architectures for each task. We present Masked-attention Mask Transformer (Mask2Former), a new architecture capable of addressing any image segmentation task (panoptic, instance or semantic). Its key components include masked attention, which extracts localized features by constraining cross-attention within predicted mask regions. In addition to reducing the research effort by at least three times, it outperforms the best specialized architectures by a significant margin on four popular datasets. Most notably, Mask2Former sets a new state-of-the-art for panoptic segmentation (57.8 PQ on COCO), instance segmentation (50.1 AP on COCO) and semantic segmentation (57.7 mIoU on ADE20K).* Tips: -- Mask2Former uses the same preprocessing and postprocessing steps as [MaskFormer](maskformer). Use [`MaskFormerImageProcessor`] or [`AutoImageProcessor`] to prepare images and optional targets for the model. -- To get the final segmentation, depending on the task, you can call [`~MaskFormerImageProcessor.post_process_semantic_segmentation`] or [`~MaskFormerImageProcessor.post_process_instance_segmentation`] or [`~MaskFormerImageProcessor.post_process_panoptic_segmentation`]. All three tasks can be solved using [`Mask2FormerForUniversalSegmentation`] output, panoptic segmentation accepts an optional `label_ids_to_fuse` argument to fuse instances of the target object/s (e.g. sky) together. +- Mask2Former uses the same preprocessing and postprocessing steps as [MaskFormer](maskformer). Use [`Mask2FormerImageProcessor`] or [`AutoImageProcessor`] to prepare images and optional targets for the model. +- To get the final segmentation, depending on the task, you can call [`~Mask2FormerImageProcessor.post_process_semantic_segmentation`] or [`~Mask2FormerImageProcessor.post_process_instance_segmentation`] or [`~Mask2FormerImageProcessor.post_process_panoptic_segmentation`]. 
All three tasks can be solved using [`Mask2FormerForUniversalSegmentation`] output; panoptic segmentation accepts an optional `label_ids_to_fuse` argument to fuse instances of the target object(s) (e.g. sky) together. This model was contributed by [Shivalika Singh](https://huggingface.co/shivi) and [Alara Dirik](https://huggingface.co/adirik). The original code can be found [here](https://github.com/facebookresearch/Mask2Former). @@ -55,3 +55,12 @@ The resource should ideally demonstrate something new instead of duplicating an [[autodoc]] Mask2FormerForUniversalSegmentation - forward + +## Mask2FormerImageProcessor + +[[autodoc]] Mask2FormerImageProcessor + - preprocess + - encode_inputs + - post_process_semantic_segmentation + - post_process_instance_segmentation + - post_process_panoptic_segmentation \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b59523ed5013bd..8edf896e95e8be 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -799,6 +799,7 @@ _import_structure["models.layoutlmv2"].extend(["LayoutLMv2FeatureExtractor", "LayoutLMv2ImageProcessor"]) _import_structure["models.layoutlmv3"].extend(["LayoutLMv3FeatureExtractor", "LayoutLMv3ImageProcessor"]) _import_structure["models.levit"].extend(["LevitFeatureExtractor", "LevitImageProcessor"]) + _import_structure["models.mask2former"].append("Mask2FormerImageProcessor") _import_structure["models.maskformer"].extend(["MaskFormerFeatureExtractor", "MaskFormerImageProcessor"]) _import_structure["models.mobilenet_v1"].extend(["MobileNetV1FeatureExtractor", "MobileNetV1ImageProcessor"]) _import_structure["models.mobilenet_v2"].extend(["MobileNetV2FeatureExtractor", "MobileNetV2ImageProcessor"]) @@ -4152,6 +4153,7 @@ from .models.layoutlmv2 import LayoutLMv2FeatureExtractor, LayoutLMv2ImageProcessor from .models.layoutlmv3 import LayoutLMv3FeatureExtractor, LayoutLMv3ImageProcessor from .models.levit import LevitFeatureExtractor, LevitImageProcessor + from .models.mask2former import Mask2FormerImageProcessor from .models.maskformer import MaskFormerFeatureExtractor, MaskFormerImageProcessor from .models.mobilenet_v1 import MobileNetV1FeatureExtractor, MobileNetV1ImageProcessor from .models.mobilenet_v2 import MobileNetV2FeatureExtractor, MobileNetV2ImageProcessor diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index ca0621a7841c3b..2c6bd221d174b8 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -62,7 +62,7 @@ ("layoutlmv2", "LayoutLMv2ImageProcessor"), ("layoutlmv3", "LayoutLMv3ImageProcessor"), ("levit", "LevitImageProcessor"), - ("mask2former", "MaskFormerImageProcessor"), + ("mask2former", "Mask2FormerImageProcessor"), ("maskformer", "MaskFormerImageProcessor"), ("mobilenet_v1", "MobileNetV1ImageProcessor"), ("mobilenet_v2", "MobileNetV2ImageProcessor"), diff --git a/src/transformers/models/mask2former/__init__.py b/src/transformers/models/mask2former/__init__.py index 995533c0102de1..4b9adf042c29c0 100644 --- a/src/transformers/models/mask2former/__init__.py +++ b/src/transformers/models/mask2former/__init__.py @@ -27,6 +27,13 @@ ], } +try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["image_processing_mask2former"] = ["Mask2FormerImageProcessor"] try: if not is_torch_available(): raise OptionalDependencyNotAvailable() @@ -44,6 +51,14 @@ if 
TYPE_CHECKING: from .configuration_mask2former import MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, Mask2FormerConfig + try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .image_processing_mask2former import Mask2FormerImageProcessor + try: if not is_torch_available(): raise OptionalDependencyNotAvailable() diff --git a/src/transformers/models/mask2former/convert_mask2former_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/mask2former/convert_mask2former_original_pytorch_checkpoint_to_pytorch.py index 19bba580783f58..b53e695db7ebc2 100644 --- a/src/transformers/models/mask2former/convert_mask2former_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/mask2former/convert_mask2former_original_pytorch_checkpoint_to_pytorch.py @@ -33,8 +33,8 @@ from transformers import ( Mask2FormerConfig, Mask2FormerForUniversalSegmentation, + Mask2FormerImageProcessor, Mask2FormerModel, - MaskFormerImageProcessor, SwinConfig, ) from transformers.models.mask2former.modeling_mask2former import ( @@ -193,11 +193,11 @@ def __call__(self, original_config: object) -> Mask2FormerConfig: class OriginalMask2FormerConfigToFeatureExtractorConverter: - def __call__(self, original_config: object) -> MaskFormerImageProcessor: + def __call__(self, original_config: object) -> Mask2FormerImageProcessor: model = original_config.MODEL model_input = original_config.INPUT - return MaskFormerImageProcessor( + return Mask2FormerImageProcessor( image_mean=(torch.tensor(model.PIXEL_MEAN) / 255).tolist(), image_std=(torch.tensor(model.PIXEL_STD) / 255).tolist(), size=model_input.MIN_SIZE_TEST, @@ -847,7 +847,7 @@ def using_dirs(checkpoints_dir: Path, config_dir: Path) -> Iterator[Tuple[object def test( original_model, our_model: Mask2FormerForUniversalSegmentation, - feature_extractor: MaskFormerImageProcessor, + feature_extractor: Mask2FormerImageProcessor, tolerance: float, ): with torch.no_grad(): diff --git a/src/transformers/models/mask2former/image_processing_mask2former.py b/src/transformers/models/mask2former/image_processing_mask2former.py new file mode 100644 index 00000000000000..01c267cdd48c53 --- /dev/null +++ b/src/transformers/models/mask2former/image_processing_mask2former.py @@ -0,0 +1,1149 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Image processor class for Mask2Former.""" + +import math +import warnings +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union + +import numpy as np + +from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from transformers.image_transforms import ( + PaddingMode, + get_resize_output_image_size, + normalize, + pad, + rescale, + resize, + to_channel_dimension_format, + to_numpy_array, +) +from transformers.image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_batched, + valid_images, +) +from transformers.utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + TensorType, + is_torch_available, + is_torch_tensor, + logging, +) + + +logger = logging.get_logger(__name__) + + +if is_torch_available(): + import torch + from torch import nn + + +# Copied from transformers.models.detr.image_processing_detr.max_across_indices +def max_across_indices(values: Iterable[Any]) -> List[Any]: + """ + Return the maximum value across all indices of an iterable of values. + """ + return [max(values_i) for values_i in zip(*values)] + + +# Copied from transformers.models.detr.image_processing_detr.get_max_height_width +def get_max_height_width(images: List[np.ndarray]) -> List[int]: + """ + Get the maximum height and width across all images in a batch. + """ + input_channel_dimension = infer_channel_dimension_format(images[0]) + + if input_channel_dimension == ChannelDimension.FIRST: + _, max_height, max_width = max_across_indices([img.shape for img in images]) + elif input_channel_dimension == ChannelDimension.LAST: + max_height, max_width, _ = max_across_indices([img.shape for img in images]) + else: + raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}") + return (max_height, max_width) + + +# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask +def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray: + """ + Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding. + + Args: + image (`np.ndarray`): + Image to make the pixel mask for. + output_size (`Tuple[int, int]`): + Output size of the mask. + """ + input_height, input_width = get_image_size(image) + mask = np.zeros(output_size, dtype=np.int64) + mask[:input_height, :input_width] = 1 + return mask + + +# Copied from transformers.models.detr.image_processing_detr.binary_mask_to_rle +def binary_mask_to_rle(mask): + """ + Converts given binary mask of shape `(height, width)` to the run-length encoding (RLE) format. + + Args: + mask (`torch.Tensor` or `numpy.array`): + A binary mask tensor of shape `(height, width)` where 0 denotes background and 1 denotes the target + segment_id or class_id. + Returns: + `List`: Run-length encoded list of the binary mask. Refer to COCO API for more information about the RLE + format. + """ + if is_torch_tensor(mask): + mask = mask.numpy() + + pixels = mask.flatten() + pixels = np.concatenate([[0], pixels, [0]]) + runs = np.where(pixels[1:] != pixels[:-1])[0] + 1 + runs[1::2] -= runs[::2] + return [x for x in runs] + + +# Copied from transformers.models.detr.image_processing_detr.convert_segmentation_to_rle +def convert_segmentation_to_rle(segmentation): + """ + Converts given segmentation map of shape `(height, width)` to the run-length encoding (RLE) format. 
+ + Args: + segmentation (`torch.Tensor` or `numpy.array`): + A segmentation map of shape `(height, width)` where each value denotes a segment or class id. + Returns: + `List[List]`: A list of lists, where each list is the run-length encoding of a segment / class id. + """ + segment_ids = torch.unique(segmentation) + + run_length_encodings = [] + for idx in segment_ids: + mask = torch.where(segmentation == idx, 1, 0) + rle = binary_mask_to_rle(mask) + run_length_encodings.append(rle) + + return run_length_encodings + + +# Copied from transformers.models.detr.image_processing_detr.remove_low_and_no_objects +def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels): + """ + Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and + `labels`. + + Args: + masks (`torch.Tensor`): + A tensor of shape `(num_queries, height, width)`. + scores (`torch.Tensor`): + A tensor of shape `(num_queries)`. + labels (`torch.Tensor`): + A tensor of shape `(num_queries)`. + object_mask_threshold (`float`): + A number between 0 and 1 used to binarize the masks. + Raises: + `ValueError`: Raised when the first dimension doesn't match in all input tensors. + Returns: + `Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` without the region + < `object_mask_threshold`. + """ + if not (masks.shape[0] == scores.shape[0] == labels.shape[0]): + raise ValueError("mask, scores and labels must have the same shape!") + + to_keep = labels.ne(num_labels) & (scores > object_mask_threshold) + + return masks[to_keep], scores[to_keep], labels[to_keep] + + +# Copied from transformers.models.detr.image_processing_detr.check_segment_validity +def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8): + # Get the mask associated with the k class + mask_k = mask_labels == k + mask_k_area = mask_k.sum() + + # Compute the area of all the stuff in query k + original_area = (mask_probs[k] >= mask_threshold).sum() + mask_exists = mask_k_area > 0 and original_area > 0 + + # Eliminate disconnected tiny segments + if mask_exists: + area_ratio = mask_k_area / original_area + if not area_ratio.item() > overlap_mask_area_threshold: + mask_exists = False + + return mask_exists, mask_k + + +# Copied from transformers.models.detr.image_processing_detr.compute_segments +def compute_segments( + mask_probs, + pred_scores, + pred_labels, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + label_ids_to_fuse: Optional[Set[int]] = None, + target_size: Tuple[int, int] = None, +): + height = mask_probs.shape[1] if target_size is None else target_size[0] + width = mask_probs.shape[2] if target_size is None else target_size[1] + + segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device) + segments: List[Dict] = [] + + if target_size is not None: + mask_probs = nn.functional.interpolate( + mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False + )[0] + + current_segment_id = 0 + + # Weigh each mask by its prediction score + mask_probs *= pred_scores.view(-1, 1, 1) + mask_labels = mask_probs.argmax(0) # [height, width] + + # Keep track of instances of each class + stuff_memory_list: Dict[str, int] = {} + for k in range(pred_labels.shape[0]): + pred_class = pred_labels[k].item() + should_fuse = pred_class in label_ids_to_fuse + + # Check if mask exists and large enough to be a segment + mask_exists, mask_k = 
check_segment_validity( + mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold + ) + + if mask_exists: + if pred_class in stuff_memory_list: + current_segment_id = stuff_memory_list[pred_class] + else: + current_segment_id += 1 + + # Add current object segment to final segmentation map + segmentation[mask_k] = current_segment_id + segment_score = round(pred_scores[k].item(), 6) + segments.append( + { + "id": current_segment_id, + "label_id": pred_class, + "was_fused": should_fuse, + "score": segment_score, + } + ) + if should_fuse: + stuff_memory_list[pred_class] = current_segment_id + + return segmentation, segments + + +# TODO: (Amy) Move to image_transforms +def convert_segmentation_map_to_binary_masks( + segmentation_map: "np.ndarray", + instance_id_to_semantic_id: Optional[Dict[int, int]] = None, + ignore_index: Optional[int] = None, + reduce_labels: bool = False, +): + if reduce_labels and ignore_index is None: + raise ValueError("If `reduce_labels` is True, `ignore_index` must be provided.") + + if reduce_labels: + segmentation_map = np.where(segmentation_map == 0, ignore_index, segmentation_map - 1) + + # Get unique ids (class or instance ids based on input) + all_labels = np.unique(segmentation_map) + + # Drop background label if applicable + if ignore_index is not None: + all_labels = all_labels[all_labels != ignore_index] + + # Generate a binary mask for each object instance + binary_masks = [(segmentation_map == i) for i in all_labels] + binary_masks = np.stack(binary_masks, axis=0) # (num_labels, height, width) + + # Convert instance ids to class ids + if instance_id_to_semantic_id is not None: + labels = np.zeros(all_labels.shape[0]) + + for label in all_labels: + class_id = instance_id_to_semantic_id[label + 1 if reduce_labels else label] + labels[all_labels == label] = class_id - 1 if reduce_labels else class_id + else: + labels = all_labels + + return binary_masks.astype(np.float32), labels.astype(np.int64) + + +def get_mask2former_resize_output_image_size( + image: np.ndarray, + size: Union[int, Tuple[int, int], List[int], Tuple[int]], + max_size: Optional[int] = None, + size_divisor: int = 0, + default_to_square: bool = True, +) -> tuple: + """ + Computes the output size given the desired size. + + Args: + input_image (`np.ndarray`): + The input image. + size (`int`, `Tuple[int, int]`, `List[int]`, `Tuple[int]`): + The size of the output image. + default_to_square (`bool`, *optional*, defaults to `True`): + Whether to default to square if no size is provided. + max_size (`int`, *optional*): + The maximum size of the output image. + size_divisible (`int`, *optional*, defaults to `0`): + If size_divisible is given, the output image size will be divisible by the number. + + Returns: + `Tuple[int, int]`: The output size. + """ + output_size = get_resize_output_image_size( + input_image=image, size=size, default_to_square=default_to_square, max_size=max_size + ) + + if size_divisor > 0: + height, width = output_size + height = int(math.ceil(height / size_divisor) * size_divisor) + width = int(math.ceil(width / size_divisor) * size_divisor) + output_size = (height, width) + + return output_size + + +class Mask2FormerImageProcessor(BaseImageProcessor): + r""" + Constructs a Mask2Former image processor. The image processor can be used to prepare image(s) and optional targets + for the model. + + This image processor inherits from [`BaseImageProcessor`] which contains most of the main methods. 
Users should + refer to this superclass for more information regarding those methods. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the input to a certain `size`. + size (`int`, *optional*, defaults to 800): + Resize the input to the given size. Only has an effect if `do_resize` is set to `True`. If size is a + sequence like `(width, height)`, output size will be matched to this. If size is an int, smaller edge of + the image will be matched to this number. i.e, if `height > width`, then image will be rescaled to `(size * + height / width, size)`. + max_size (`int`, *optional*, defaults to 1333): + The largest size an image dimension can have (otherwise it's capped). Only has an effect if `do_resize` is + set to `True`. + resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`): + An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`, + `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`, + `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set + to `True`. + size_divisor (`int`, *optional*, defaults to 32): + Some backbones need images divisible by a certain number. If not passed, it defaults to the value used in + Swin Transformer. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the input to a certain `scale`. + rescale_factor (`float`, *optional*, defaults to 1/ 255): + Rescale the input by the given factor. Only has an effect if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether or not to normalize the input with mean and standard deviation. + image_mean (`int`, *optional*, defaults to `[0.485, 0.456, 0.406]`): + The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean. + image_std (`int`, *optional*, defaults to `[0.229, 0.224, 0.225]`): + The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the + ImageNet std. + ignore_index (`int`, *optional*): + Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels + denoted with 0 (background) will be replaced with `ignore_index`. + reduce_labels (`bool`, *optional*, defaults to `False`): + Whether or not to decrement all label values of segmentation maps by 1. Usually used for datasets where 0 + is used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). + The background label will be replaced by `ignore_index`. + + """ + + model_input_names = ["pixel_values", "pixel_mask"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + size_divisor: int = 32, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: float = 1 / 255, + do_normalize: bool = True, + image_mean: Union[float, List[float]] = None, + image_std: Union[float, List[float]] = None, + ignore_index: Optional[int] = None, + reduce_labels: bool = False, + **kwargs + ): + if "size_divisibility" in kwargs: + warnings.warn( + "The `size_divisibility` argument is deprecated and will be removed in v4.27. Please use " + "`size_divisor` instead.", + FutureWarning, + ) + size_divisor = kwargs.pop("size_divisibility") + if "max_size" in kwargs: + warnings.warn( + "The `max_size` argument is deprecated and will be removed in v4.27. 
Please use size['longest_edge']" + " instead.", + FutureWarning, + ) + # We make max_size a private attribute so we can pass it as a default value in the preprocess method whilst + # `size` can still be pass in as an int + self._max_size = kwargs.pop("max_size") + else: + self._max_size = 1333 + + size = size if size is not None else {"shortest_edge": 800, "longest_edge": self._max_size} + size = get_size_dict(size, max_size=self._max_size, default_to_square=False) + + super().__init__(**kwargs) + self.do_resize = do_resize + self.size = size + self.resample = resample + self.size_divisor = size_divisor + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + self.ignore_index = ignore_index + self.reduce_labels = reduce_labels + + @classmethod + def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is + created using from_dict and kwargs e.g. `Mask2FormerImageProcessor.from_pretrained(checkpoint, max_size=800)` + """ + image_processor_dict = image_processor_dict.copy() + if "max_size" in kwargs: + image_processor_dict["max_size"] = kwargs.pop("max_size") + if "size_divisibility" in kwargs: + image_processor_dict["size_divisibility"] = kwargs.pop("size_divisibility") + return super().from_dict(image_processor_dict, **kwargs) + + @property + def size_divisibility(self): + warnings.warn( + "The `size_divisibility` property is deprecated and will be removed in v4.27. Please use " + "`size_divisor` instead.", + FutureWarning, + ) + return self.size_divisor + + @property + def max_size(self): + warnings.warn( + "The `max_size` property is deprecated and will be removed in v4.27. Please use size['longest_edge']" + " instead.", + FutureWarning, + ) + return self.size["longest_edge"] + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + size_divisor: int = 0, + resample: PILImageResampling = PILImageResampling.BILINEAR, + data_format=None, + **kwargs + ) -> np.ndarray: + """ + Resize the image to the given size. Size can be min_size (scalar) or `(height, width)` tuple. If size is an + int, smaller edge of the image will be matched to this number. + """ + if "max_size" in kwargs: + warnings.warn( + "The `max_size` parameter is deprecated and will be removed in v4.27. " + "Please specify in `size['longest_edge'] instead`.", + FutureWarning, + ) + max_size = kwargs.pop("max_size") + else: + max_size = None + size = get_size_dict(size, max_size=max_size, default_to_square=False) + if "shortest_edge" in size and "longest_edge" in size: + size, max_size = size["shortest_edge"], size["longest_edge"] + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + max_size = None + else: + raise ValueError( + "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got" + f" {size.keys()}." 
+ ) + size = get_mask2former_resize_output_image_size( + image=image, + size=size, + max_size=max_size, + size_divisor=size_divisor, + default_to_square=False, + ) + image = resize(image, size=size, resample=resample, data_format=data_format) + return image + + def rescale( + self, image: np.ndarray, rescale_factor: float, data_format: Optional[ChannelDimension] = None + ) -> np.ndarray: + """ + Rescale the image by the given factor. + """ + return rescale(image, rescale_factor, data_format=data_format) + + def normalize( + self, + image: np.ndarray, + mean: Union[float, Iterable[float]], + std: Union[float, Iterable[float]], + data_format: Optional[ChannelDimension] = None, + ) -> np.ndarray: + """ + Normalize the image with the given mean and standard deviation. + """ + return normalize(image, mean=mean, std=std, data_format=data_format) + + def convert_segmentation_map_to_binary_masks( + self, + segmentation_map: "np.ndarray", + instance_id_to_semantic_id: Optional[Dict[int, int]] = None, + ignore_index: Optional[int] = None, + reduce_labels: bool = False, + ): + reduce_labels = reduce_labels if reduce_labels is not None else self.reduce_labels + ignore_index = ignore_index if ignore_index is not None else self.ignore_index + return convert_segmentation_map_to_binary_masks( + segmentation_map=segmentation_map, + instance_id_to_semantic_id=instance_id_to_semantic_id, + ignore_index=ignore_index, + reduce_labels=reduce_labels, + ) + + def __call__(self, images, segmentation_maps=None, **kwargs) -> BatchFeature: + return self.preprocess(images, segmentation_maps=segmentation_maps, **kwargs) + + def _preprocess( + self, + image: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + size_divisor: int = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + ): + if do_resize: + image = self.resize(image, size=size, size_divisor=size_divisor, resample=resample) + if do_rescale: + image = self.rescale(image, rescale_factor=rescale_factor) + if do_normalize: + image = self.normalize(image, mean=image_mean, std=image_std) + return image + + def _preprocess_image( + self, + image: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + size_divisor: int = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Preprocesses a single image.""" + # All transformations expect numpy arrays. 
+ image = to_numpy_array(image) + image = self._preprocess( + image=image, + do_resize=do_resize, + size=size, + size_divisor=size_divisor, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + ) + if data_format is not None: + image = to_channel_dimension_format(image, data_format) + return image + + def _preprocess_mask( + self, + segmentation_map: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + size_divisor: int = 0, + ) -> np.ndarray: + """Preprocesses a single mask.""" + segmentation_map = to_numpy_array(segmentation_map) + # Add channel dimension if missing - needed for certain transformations + added_channel_dim = False + if segmentation_map.ndim == 2: + added_channel_dim = True + segmentation_map = segmentation_map[None, ...] + # TODO: (Amy) + # Rework segmentation map processing to include reducing labels and resizing which doesn't + # drop segment IDs > 255. + segmentation_map = self._preprocess( + image=segmentation_map, + do_resize=do_resize, + resample=PILImageResampling.NEAREST, + size=size, + size_divisor=size_divisor, + do_rescale=False, + do_normalize=False, + ) + # Remove extra channel dimension if added for processing + if added_channel_dim: + segmentation_map = segmentation_map.squeeze(0) + return segmentation_map + + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, + instance_id_to_semantic_id: Optional[Dict[int, int]] = None, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + size_divisor: Optional[int] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + ignore_index: Optional[int] = None, + reduce_labels: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + **kwargs + ) -> BatchFeature: + if "pad_and_return_pixel_mask" in kwargs: + warnings.warn( + "The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version", + FutureWarning, + ) + + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False, max_size=self._max_size) + size_divisor = size_divisor if size_divisor is not None else self.size_divisor + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + ignore_index = ignore_index if ignore_index is not None else self.ignore_index + reduce_labels = reduce_labels if reduce_labels is not None else self.reduce_labels + + if do_resize is not None and (size is None or size_divisor is None): + raise ValueError("If `do_resize` is True, `size` and `size_divisor` must be provided.") + + if do_rescale is not None and rescale_factor is None: + raise ValueError("If `do_rescale` is True, `rescale_factor` must be provided.") + + 
if do_normalize is not None and (image_mean is None or image_std is None): + raise ValueError("If `do_normalize` is True, `image_mean` and `image_std` must be provided.") + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + if segmentation_maps is not None and not valid_images(segmentation_maps): + raise ValueError( + "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + if not is_batched(images): + images = [images] + segmentation_maps = [segmentation_maps] if segmentation_maps is not None else None + + if segmentation_maps is not None and len(images) != len(segmentation_maps): + raise ValueError("Images and segmentation maps must have the same length.") + + images = [ + self._preprocess_image( + image, + do_resize=do_resize, + size=size, + size_divisor=size_divisor, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + ) + for image in images + ] + + if segmentation_maps is not None: + segmentation_maps = [ + self._preprocess_mask(segmentation_map, do_resize, size, size_divisor) + for segmentation_map in segmentation_maps + ] + encoded_inputs = self.encode_inputs( + images, segmentation_maps, instance_id_to_semantic_id, ignore_index, reduce_labels, return_tensors + ) + return encoded_inputs + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image + def _pad_image( + self, + image: np.ndarray, + output_size: Tuple[int, int], + constant_values: Union[float, Iterable[float]] = 0, + data_format: Optional[ChannelDimension] = None, + ) -> np.ndarray: + """ + Pad an image with zeros to the given size. + """ + input_height, input_width = get_image_size(image) + output_height, output_width = output_size + + pad_bottom = output_height - input_height + pad_right = output_width - input_width + padding = ((0, pad_bottom), (0, pad_right)) + padded_image = pad( + image, padding, mode=PaddingMode.CONSTANT, constant_values=constant_values, data_format=data_format + ) + return padded_image + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.pad + def pad( + self, + images: List[np.ndarray], + constant_values: Union[float, Iterable[float]] = 0, + return_pixel_mask: bool = True, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = None, + ) -> np.ndarray: + """ + Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width + in the batch and optionally returns their corresponding pixel mask. + + Args: + image (`np.ndarray`): + Image to pad. + constant_values (`float` or `Iterable[float]`, *optional*): + The value to use for the padding if `mode` is `"constant"`. + return_pixel_mask (`bool`, *optional*, defaults to `True`): + Whether to return a pixel mask. + input_channel_dimension (`ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be inferred from the input image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. 
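+
+        Example (a minimal sketch; the toy array sizes below are made up purely for illustration):
+
+        ```python
+        >>> import numpy as np
+        >>> from transformers import Mask2FormerImageProcessor
+
+        >>> image_processor = Mask2FormerImageProcessor()
+        >>> # two channel-first images with different spatial sizes
+        >>> images = [np.zeros((3, 4, 6), dtype=np.float32), np.zeros((3, 5, 5), dtype=np.float32)]
+        >>> batch = image_processor.pad(images, return_tensors="pt")
+        >>> batch["pixel_values"].shape  # padded to the largest height and width in the batch
+        torch.Size([2, 3, 5, 6])
+        >>> batch["pixel_mask"].shape  # 1 marks real pixels, 0 marks padding
+        torch.Size([2, 5, 6])
+        ```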
+ """ + pad_size = get_max_height_width(images) + + padded_images = [ + self._pad_image(image, pad_size, constant_values=constant_values, data_format=data_format) + for image in images + ] + data = {"pixel_values": padded_images} + + if return_pixel_mask: + masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images] + data["pixel_mask"] = masks + + return BatchFeature(data=data, tensor_type=return_tensors) + + def encode_inputs( + self, + pixel_values_list: List[ImageInput], + segmentation_maps: ImageInput = None, + instance_id_to_semantic_id: Optional[Union[List[Dict[int, int]], Dict[int, int]]] = None, + ignore_index: Optional[int] = None, + reduce_labels: bool = False, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs + ): + """ + Pad images up to the largest image in a batch and create a corresponding `pixel_mask`. + + Mask2Former addresses semantic segmentation with a mask classification paradigm, thus input segmentation maps + will be converted to lists of binary masks and their respective labels. Let's see an example, assuming + `segmentation_maps = [[2,6,7,9]]`, the output will contain `mask_labels = + [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]` (four binary masks) and `class_labels = [2,6,7,9]`, the labels for + each mask. + + Args: + pixel_values_list (`List[ImageInput]`): + List of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height, + width)`. + + segmentation_maps (`ImageInput`, *optional*): + The corresponding semantic segmentation maps with the pixel-wise annotations. + + (`bool`, *optional*, defaults to `True`): + Whether or not to pad images up to the largest image in a batch and create a pixel mask. + + If left to the default, will return a pixel mask that is: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + + instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*): + A mapping between object instance ids and class ids. If passed, `segmentation_maps` is treated as an + instance segmentation map where each pixel represents an instance id. Can be provided as a single + dictionary with a global/dataset-level mapping or as a list of dictionaries (one per image), to map + instance ids in each image separately. + + return_tensors (`str` or [`~file_utils.TensorType`], *optional*): + If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor` + objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **pixel_values** -- Pixel values to be fed to a model. + - **pixel_mask** -- Pixel mask to be fed to a model (when `=True` or if `pixel_mask` is in + `self.model_input_names`). + - **mask_labels** -- Optional list of mask labels of shape `(labels, height, width)` to be fed to a model + (when `annotations` are provided). + - **class_labels** -- Optional list of class labels of shape `(labels)` to be fed to a model (when + `annotations` are provided). They identify the labels of `mask_labels`, e.g. the label of + `mask_labels[i][j]` if `class_labels[i][j]`. 
+ """ + ignore_index = self.ignore_index if ignore_index is None else ignore_index + reduce_labels = self.reduce_labels if reduce_labels is None else reduce_labels + + if "pad_and_return_pixel_mask" in kwargs: + warnings.warn( + "The `pad_and_return_pixel_mask` argument has no effect and will be removed in v4.27", FutureWarning + ) + + pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list] + encoded_inputs = self.pad(pixel_values_list, return_tensors=return_tensors) + + if segmentation_maps is not None: + mask_labels = [] + class_labels = [] + pad_size = get_max_height_width(pixel_values_list) + # Convert to list of binary masks and labels + for idx, segmentation_map in enumerate(segmentation_maps): + segmentation_map = to_numpy_array(segmentation_map) + if isinstance(instance_id_to_semantic_id, list): + instance_id = instance_id_to_semantic_id[idx] + else: + instance_id = instance_id_to_semantic_id + # Use instance2class_id mapping per image + masks, classes = self.convert_segmentation_map_to_binary_masks( + segmentation_map, instance_id, ignore_index=ignore_index, reduce_labels=reduce_labels + ) + # We add an axis to make them compatible with the transformations library + # this will be removed in the future + masks = [mask[None, ...] for mask in masks] + masks = [ + self._pad_image(image=mask, output_size=pad_size, constant_values=ignore_index) for mask in masks + ] + masks = np.concatenate(masks, axis=0) + mask_labels.append(torch.from_numpy(masks)) + class_labels.append(torch.from_numpy(classes)) + + # we cannot batch them since they don't share a common class size + encoded_inputs["mask_labels"] = mask_labels + encoded_inputs["class_labels"] = class_labels + + return encoded_inputs + + def post_process_semantic_segmentation( + self, outputs, target_sizes: Optional[List[Tuple[int, int]]] = None + ) -> "torch.Tensor": + """ + Converts the output of [`Mask2FormerForUniversalSegmentation`] into semantic segmentation maps. Only supports + PyTorch. + + Args: + outputs ([`Mask2FormerForUniversalSegmentation`]): + Raw outputs of the model. + target_sizes (`List[Tuple[int, int]]`, *optional*): + List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction. If left to None, predictions will not be resized. + Returns: + `List[torch.Tensor]`: + A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width) + corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each + `torch.Tensor` correspond to a semantic class id. 
+ """ + class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] + + # Scale back to preprocessed image size - (384, 384) for all models + masks_queries_logits = torch.nn.functional.interpolate( + masks_queries_logits, size=(384, 384), mode="bilinear", align_corners=False + ) + + # Remove the null class `[..., :-1]` + masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1] + masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Semantic segmentation logits of shape (batch_size, num_classes, height, width) + segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs) + batch_size = class_queries_logits.shape[0] + + # Resize logits and compute semantic segmentation maps + if target_sizes is not None: + if batch_size != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + semantic_segmentation = [] + for idx in range(batch_size): + resized_logits = torch.nn.functional.interpolate( + segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) + else: + semantic_segmentation = segmentation.argmax(dim=1) + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] + + return semantic_segmentation + + def post_process_instance_segmentation( + self, + outputs, + threshold: float = 0.5, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + target_sizes: Optional[List[Tuple[int, int]]] = None, + return_coco_annotation: Optional[bool] = False, + return_binary_maps: Optional[bool] = False, + ) -> List[Dict]: + """ + Converts the output of [`Mask2FormerForUniversalSegmentationOutput`] into instance segmentation predictions. + Only supports PyTorch. + + Args: + outputs ([`Mask2FormerForUniversalSegmentation`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.5): + The probability score threshold to keep predicted instance masks. + mask_threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8): + The overlap mask area threshold to merge or discard small disconnected parts within each binary + instance mask. + target_sizes (`List[Tuple]`, *optional*): + List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction. If left to None, predictions will not be resized. + return_coco_annotation (`bool`, *optional*, defaults to `False`): + If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE) format. + return_binary_maps (`bool`, *optional*, defaults to `False`): + If set to `True`, segmentation maps are returned as a concatenated tensor of binary segmentation maps + (one per detected instance). + Returns: + `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: + - **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or + `List[List]` run-length encoding (RLE) of the segmentation map if return_coco_annotation is set to + `True`. Set to `None` if no mask if found above `threshold`. 
+ - **segments_info** -- A dictionary that contains additional information on each segment. + - **id** -- An integer representing the `segment_id`. + - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`. + - **score** -- Prediction score of segment with `segment_id`. + """ + if return_coco_annotation and return_binary_maps: + raise ValueError("return_coco_annotation and return_binary_maps can not be both set to True.") + + # [batch_size, num_queries, num_classes+1] + class_queries_logits = outputs.class_queries_logits + # [batch_size, num_queries, height, width] + masks_queries_logits = outputs.masks_queries_logits + + # Scale back to preprocessed image size - (384, 384) for all models + masks_queries_logits = torch.nn.functional.interpolate( + masks_queries_logits, size=(384, 384), mode="bilinear", align_corners=False + ) + + device = masks_queries_logits.device + num_classes = class_queries_logits.shape[-1] - 1 + num_queries = class_queries_logits.shape[-2] + + # Loop over items in batch size + results: List[Dict[str, TensorType]] = [] + + for i in range(class_queries_logits.shape[0]): + mask_pred = masks_queries_logits[i] + mask_cls = class_queries_logits[i] + + scores = torch.nn.functional.softmax(mask_cls, dim=-1)[:, :-1] + labels = torch.arange(num_classes, device=device).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1) + + scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False) + labels_per_image = labels[topk_indices] + + topk_indices = topk_indices // num_classes + mask_pred = mask_pred[topk_indices] + pred_masks = (mask_pred > 0).float() + + # Calculate average mask prob + mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * pred_masks.flatten(1)).sum(1) / ( + pred_masks.flatten(1).sum(1) + 1e-6 + ) + pred_scores = scores_per_image * mask_scores_per_image + pred_classes = labels_per_image + + segmentation = torch.zeros((384, 384)) - 1 + if target_sizes is not None: + segmentation = torch.zeros(target_sizes[i]) - 1 + pred_masks = torch.nn.functional.interpolate( + pred_masks.unsqueeze(0), size=target_sizes[i], mode="nearest" + )[0] + + instance_maps, segments = [], [] + current_segment_id = 0 + for j in range(num_queries): + score = pred_scores[j].item() + + if not torch.all(pred_masks[j] == 0) and score >= threshold: + segmentation[pred_masks[j] == 1] = current_segment_id + segments.append( + { + "id": current_segment_id, + "label_id": pred_classes[j].item(), + "was_fused": False, + "score": round(score, 6), + } + ) + current_segment_id += 1 + instance_maps.append(pred_masks[j]) + # Return segmentation map in run-length encoding (RLE) format + if return_coco_annotation: + segmentation = convert_segmentation_to_rle(segmentation) + + # Return a concatenated tensor of binary instance maps + if return_binary_maps and len(instance_maps) != 0: + segmentation = torch.stack(instance_maps, dim=0) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + def post_process_panoptic_segmentation( + self, + outputs, + threshold: float = 0.5, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + label_ids_to_fuse: Optional[Set[int]] = None, + target_sizes: Optional[List[Tuple[int, int]]] = None, + ) -> List[Dict]: + """ + Converts the output of [`Mask2FormerForUniversalSegmentationOutput`] into image panoptic segmentation + predictions. Only supports PyTorch. 
+ + Args: + outputs ([`Mask2FormerForUniversalSegmentationOutput`]): + The outputs from [`Mask2FormerForUniversalSegmentation`]. + threshold (`float`, *optional*, defaults to 0.5): + The probability score threshold to keep predicted instance masks. + mask_threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8): + The overlap mask area threshold to merge or discard small disconnected parts within each binary + instance mask. + label_ids_to_fuse (`Set[int]`, *optional*): + The labels in this set will have all their instances fused together. For instance, we could say + there can only be one sky in an image, but several persons, so the label ID for sky would be in that + set, but not the one for person. + target_sizes (`List[Tuple]`, *optional*): + List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested + final size (height, width) of each prediction in batch. If left to None, predictions will not be + resized. + + Returns: + `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: + - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`, set + to `None` if no mask is found above `threshold`. If `target_sizes` is specified, segmentation is resized + to the corresponding `target_sizes` entry. + - **segments_info** -- A dictionary that contains additional information on each segment. + - **id** -- an integer representing the `segment_id`. + - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`. + - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise. + Multiple instances of the same class / label were fused and assigned a single `segment_id`. + - **score** -- Prediction score of segment with `segment_id`. + """ + + if label_ids_to_fuse is None: + logger.warning("`label_ids_to_fuse` unset. 
No instance will be fused.") + label_ids_to_fuse = set() + + class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] + + # Scale back to preprocessed image size - (384, 384) for all models + masks_queries_logits = torch.nn.functional.interpolate( + masks_queries_logits, size=(384, 384), mode="bilinear", align_corners=False + ) + + batch_size = class_queries_logits.shape[0] + num_labels = class_queries_logits.shape[-1] - 1 + + mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Predicted label and score of each query (batch_size, num_queries) + pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1) + + # Loop over items in batch size + results: List[Dict[str, TensorType]] = [] + + for i in range(batch_size): + mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects( + mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels + ) + + # No mask found + if mask_probs_item.shape[0] <= 0: + height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:] + segmentation = torch.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) + continue + + # Get segmentation map and segment information of batch item + target_size = target_sizes[i] if target_sizes is not None else None + segmentation, segments = compute_segments( + mask_probs=mask_probs_item, + pred_scores=pred_scores_item, + pred_labels=pred_labels_item, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + label_ids_to_fuse=label_ids_to_fuse, + target_size=target_size, + ) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index abcba3b65a36e0..9f6feaa87b27c7 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -49,6 +49,7 @@ _CONFIG_FOR_DOC = "Mask2FormerConfig" _CHECKPOINT_FOR_DOC = "facebook/mask2former-swin-small-coco-instance" +_IMAGE_PROCESSOR_FOR_DOC = "Mask2FormerImageProcessor" MASK2FORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ "facebook/mask2former-swin-small-coco-instance", @@ -194,10 +195,10 @@ class Mask2FormerForUniversalSegmentationOutput(ModelOutput): """ Class for outputs of [`Mask2FormerForUniversalSegmentationOutput`]. - This output can be directly passed to [`~MaskFormerImageProcessor.post_process_semantic_segmentation`] or - [`~MaskFormerImageProcessor.post_process_instance_segmentation`] or - [`~MaskFormerImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please, see - [`~MaskFormerImageProcessor] for details regarding usage. + This output can be directly passed to [`~Mask2FormerImageProcessor.post_process_semantic_segmentation`] or + [`~Mask2FormerImageProcessor.post_process_instance_segmentation`] or + [`~Mask2FormerImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please, see + [`~Mask2FormerImageProcessor] for details regarding usage. 
Args: loss (`torch.Tensor`, *optional*): diff --git a/src/transformers/models/maskformer/image_processing_maskformer.py b/src/transformers/models/maskformer/image_processing_maskformer.py index 255522ec7b017f..89f6c6920a8d4b 100644 --- a/src/transformers/models/maskformer/image_processing_maskformer.py +++ b/src/transformers/models/maskformer/image_processing_maskformer.py @@ -1016,6 +1016,7 @@ def post_process_instance_segmentation( overlap_mask_area_threshold: float = 0.8, target_sizes: Optional[List[Tuple[int, int]]] = None, return_coco_annotation: Optional[bool] = False, + return_binary_maps: Optional[bool] = False, ) -> List[Dict]: """ Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into instance segmentation predictions. Only @@ -1034,9 +1035,11 @@ def post_process_instance_segmentation( target_sizes (`List[Tuple]`, *optional*): List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested final size (height, width) of each prediction. If left to None, predictions will not be resized. - return_coco_annotation (`bool`, *optional*): - Defaults to `False`. If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE) - format. + return_coco_annotation (`bool`, *optional*, defaults to `False`): + If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE) format. + return_binary_maps (`bool`, *optional*, defaults to `False`): + If set to `True`, segmentation maps are returned as a concatenated tensor of binary segmentation maps + (one per detected instance). Returns: `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: - **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or @@ -1047,47 +1050,73 @@ def post_process_instance_segmentation( - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`. - **score** -- Prediction score of segment with `segment_id`. 
""" - class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] - masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] - - batch_size = class_queries_logits.shape[0] - num_labels = class_queries_logits.shape[-1] - 1 + if return_coco_annotation and return_binary_maps: + raise ValueError("return_coco_annotation and return_binary_maps can not be both set to True.") - mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + # [batch_size, num_queries, num_classes+1] + class_queries_logits = outputs.class_queries_logits + # [batch_size, num_queries, height, width] + masks_queries_logits = outputs.masks_queries_logits - # Predicted label and score of each query (batch_size, num_queries) - pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1) + device = masks_queries_logits.device + num_classes = class_queries_logits.shape[-1] - 1 + num_queries = class_queries_logits.shape[-2] # Loop over items in batch size results: List[Dict[str, TensorType]] = [] - for i in range(batch_size): - mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects( - mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels - ) + for i in range(class_queries_logits.shape[0]): + mask_pred = masks_queries_logits[i] + mask_cls = class_queries_logits[i] - # No mask found - if mask_probs_item.shape[0] <= 0: - height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:] - segmentation = torch.zeros((height, width)) - 1 - results.append({"segmentation": segmentation, "segments_info": []}) - continue + scores = torch.nn.functional.softmax(mask_cls, dim=-1)[:, :-1] + labels = torch.arange(num_classes, device=device).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1) - # Get segmentation map and segment information of batch item - target_size = target_sizes[i] if target_sizes is not None else None - segmentation, segments = compute_segments( - mask_probs=mask_probs_item, - pred_scores=pred_scores_item, - pred_labels=pred_labels_item, - mask_threshold=mask_threshold, - overlap_mask_area_threshold=overlap_mask_area_threshold, - label_ids_to_fuse=[], - target_size=target_size, - ) + scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False) + labels_per_image = labels[topk_indices] - # Return segmentation map in run-length encoding (RLE) format - if return_coco_annotation: - segmentation = convert_segmentation_to_rle(segmentation) + topk_indices = topk_indices // num_classes + mask_pred = mask_pred[topk_indices] + pred_masks = (mask_pred > 0).float() + + # Calculate average mask prob + mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * pred_masks.flatten(1)).sum(1) / ( + pred_masks.flatten(1).sum(1) + 1e-6 + ) + pred_scores = scores_per_image * mask_scores_per_image + pred_classes = labels_per_image + + segmentation = torch.zeros(masks_queries_logits.shape[2:]) - 1 + if target_sizes is not None: + segmentation = torch.zeros(target_sizes[i]) - 1 + pred_masks = torch.nn.functional.interpolate( + pred_masks.unsqueeze(0), size=target_sizes[i], mode="nearest" + )[0] + + instance_maps, segments = [], [] + current_segment_id = 0 + for j in range(num_queries): + score = pred_scores[j].item() + + if not torch.all(pred_masks[j] == 0) and score >= threshold: + segmentation[pred_masks[j] == 1] = current_segment_id + segments.append( + { + "id": current_segment_id, + "label_id": pred_classes[j].item(), + 
"was_fused": False, + "score": round(score, 6), + } + ) + current_segment_id += 1 + instance_maps.append(pred_masks[j]) + # Return segmentation map in run-length encoding (RLE) format + if return_coco_annotation: + segmentation = convert_segmentation_to_rle(segmentation) + + # Return a concatenated tensor of binary instance maps + if return_binary_maps and len(instance_maps) != 0: + segmentation = torch.stack(instance_maps, dim=0) results.append({"segmentation": segmentation, "segments_info": segments}) return results diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 739ea3a618585a..6cde6701cf6b37 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -269,6 +269,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class Mask2FormerImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class MaskFormerFeatureExtractor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/models/mask2former/test_image_processing_mask2former.py b/tests/models/mask2former/test_image_processing_mask2former.py new file mode 100644 index 00000000000000..983b4fa7e6db0c --- /dev/null +++ b/tests/models/mask2former/test_image_processing_mask2former.py @@ -0,0 +1,611 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import unittest + +import numpy as np +from datasets import load_dataset + +from huggingface_hub import hf_hub_download +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_image_processing_common import ImageProcessingSavingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + + if is_vision_available(): + from transformers import Mask2FormerImageProcessor + from transformers.models.mask2former.image_processing_mask2former import binary_mask_to_rle + from transformers.models.mask2former.modeling_mask2former import Mask2FormerForUniversalSegmentationOutput + +if is_vision_available(): + from PIL import Image + + +class Mask2FormerImageProcessingTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + min_resolution=30, + max_resolution=400, + size=None, + do_resize=True, + do_normalize=True, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], + num_labels=10, + do_reduce_labels=True, + ignore_index=255, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.size = {"shortest_edge": 32, "longest_edge": 1333} if size is None else size + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.size_divisor = 0 + # for the post_process_functions + self.batch_size = 2 + self.num_queries = 3 + self.num_classes = 2 + self.height = 3 + self.width = 4 + self.num_labels = num_labels + self.do_reduce_labels = do_reduce_labels + self.ignore_index = ignore_index + + def prepare_image_processor_dict(self): + return { + "do_resize": self.do_resize, + "size": self.size, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + "size_divisor": self.size_divisor, + "num_labels": self.num_labels, + "do_reduce_labels": self.do_reduce_labels, + "ignore_index": self.ignore_index, + } + + def get_expected_values(self, image_inputs, batched=False): + """ + This function computes the expected height and width when providing images to Mask2FormerImageProcessor, + assuming do_resize is set to True with a scalar size. 
+ """ + if not batched: + image = image_inputs[0] + if isinstance(image, Image.Image): + w, h = image.size + else: + h, w = image.shape[1], image.shape[2] + if w < h: + expected_height = int(self.size["shortest_edge"] * h / w) + expected_width = self.size["shortest_edge"] + elif w > h: + expected_height = self.size["shortest_edge"] + expected_width = int(self.size["shortest_edge"] * w / h) + else: + expected_height = self.size["shortest_edge"] + expected_width = self.size["shortest_edge"] + + else: + expected_values = [] + for image in image_inputs: + expected_height, expected_width = self.get_expected_values([image]) + expected_values.append((expected_height, expected_width)) + expected_height = max(expected_values, key=lambda item: item[0])[0] + expected_width = max(expected_values, key=lambda item: item[1])[1] + + return expected_height, expected_width + + def get_fake_mask2former_outputs(self): + return Mask2FormerForUniversalSegmentationOutput( + # +1 for null class + class_queries_logits=torch.randn((self.batch_size, self.num_queries, self.num_classes + 1)), + masks_queries_logits=torch.randn((self.batch_size, self.num_queries, self.height, self.width)), + ) + + +@require_torch +@require_vision +class Mask2FormerImageProcessingTest(ImageProcessingSavingTestMixin, unittest.TestCase): + + image_processing_class = Mask2FormerImageProcessor if (is_vision_available() and is_torch_available()) else None + + def setUp(self): + self.image_processor_tester = Mask2FormerImageProcessingTester(self) + + @property + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + def test_image_processor_properties(self): + image_processing = self.image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "max_size")) + self.assertTrue(hasattr(image_processing, "ignore_index")) + self.assertTrue(hasattr(image_processing, "num_labels")) + + def test_image_processor_from_dict_with_kwargs(self): + image_processor = self.image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"shortest_edge": 32, "longest_edge": 1333}) + self.assertEqual(image_processor.size_divisor, 0) + + image_processor = self.image_processing_class.from_dict( + self.image_processor_dict, size=42, max_size=84, size_divisibility=8 + ) + self.assertEqual(image_processor.size, {"shortest_edge": 42, "longest_edge": 84}) + self.assertEqual(image_processor.size_divisor, 8) + + def test_batch_feature(self): + pass + + def test_call_pil(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + + expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs) + + self.assertEqual( + encoded_images.shape, + (1, self.image_processor_tester.num_channels, expected_height, expected_width), + ) + + # Test batched + expected_height, expected_width = 
self.image_processor_tester.get_expected_values(image_inputs, batched=True) + + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.image_processor_tester.batch_size, + self.image_processor_tester.num_channels, + expected_height, + expected_width, + ), + ) + + def test_call_numpy(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random numpy tensors + image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + + expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs) + + self.assertEqual( + encoded_images.shape, + (1, self.image_processor_tester.num_channels, expected_height, expected_width), + ) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + + expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs, batched=True) + + self.assertEqual( + encoded_images.shape, + ( + self.image_processor_tester.batch_size, + self.image_processor_tester.num_channels, + expected_height, + expected_width, + ), + ) + + def test_call_pytorch(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False, torchify=True) + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + + expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs) + + self.assertEqual( + encoded_images.shape, + (1, self.image_processor_tester.num_channels, expected_height, expected_width), + ) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + + expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs, batched=True) + + self.assertEqual( + encoded_images.shape, + ( + self.image_processor_tester.batch_size, + self.image_processor_tester.num_channels, + expected_height, + expected_width, + ), + ) + + def test_equivalence_pad_and_create_pixel_mask(self): + # Initialize image_processings + image_processing_1 = self.image_processing_class(**self.image_processor_dict) + image_processing_2 = self.image_processing_class( + do_resize=False, do_normalize=False, do_rescale=False, num_labels=self.image_processor_tester.num_classes + ) + # create random PyTorch tensors + image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False, torchify=True) + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test whether the method "pad_and_return_pixel_mask" and calling the image processor return the same tensors + encoded_images_with_method = image_processing_1.encode_inputs(image_inputs, return_tensors="pt") + encoded_images = image_processing_2(image_inputs, return_tensors="pt") + + self.assertTrue( + torch.allclose(encoded_images_with_method["pixel_values"], encoded_images["pixel_values"], atol=1e-4) + ) + self.assertTrue( + 
torch.allclose(encoded_images_with_method["pixel_mask"], encoded_images["pixel_mask"], atol=1e-4) + ) + + def comm_get_image_processing_inputs( + self, with_segmentation_maps=False, is_instance_map=False, segmentation_type="np" + ): + image_processing = self.image_processing_class(**self.image_processor_dict) + # prepare image and target + num_labels = self.image_processor_tester.num_labels + annotations = None + instance_id_to_semantic_id = None + image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False) + if with_segmentation_maps: + high = num_labels + if is_instance_map: + labels_expanded = list(range(num_labels)) * 2 + instance_id_to_semantic_id = { + instance_id: label_id for instance_id, label_id in enumerate(labels_expanded) + } + annotations = [ + np.random.randint(0, high * 2, (img.size[1], img.size[0])).astype(np.uint8) for img in image_inputs + ] + if segmentation_type == "pil": + annotations = [Image.fromarray(annotation) for annotation in annotations] + + inputs = image_processing( + image_inputs, + annotations, + return_tensors="pt", + instance_id_to_semantic_id=instance_id_to_semantic_id, + pad_and_return_pixel_mask=True, + ) + + return inputs + + def test_init_without_params(self): + pass + + def test_with_size_divisor(self): + size_divisors = [8, 16, 32] + weird_input_sizes = [(407, 802), (582, 1094)] + for size_divisor in size_divisors: + image_processor_dict = {**self.image_processor_dict, **{"size_divisor": size_divisor}} + image_processing = self.image_processing_class(**image_processor_dict) + for weird_input_size in weird_input_sizes: + inputs = image_processing([np.ones((3, *weird_input_size))], return_tensors="pt") + pixel_values = inputs["pixel_values"] + # check if divisible + self.assertTrue((pixel_values.shape[-1] % size_divisor) == 0) + self.assertTrue((pixel_values.shape[-2] % size_divisor) == 0) + + def test_call_with_segmentation_maps(self): + def common(is_instance_map=False, segmentation_type=None): + inputs = self.comm_get_image_processing_inputs( + with_segmentation_maps=True, is_instance_map=is_instance_map, segmentation_type=segmentation_type + ) + + mask_labels = inputs["mask_labels"] + class_labels = inputs["class_labels"] + pixel_values = inputs["pixel_values"] + + # check the batch_size + for mask_label, class_label in zip(mask_labels, class_labels): + self.assertEqual(mask_label.shape[0], class_label.shape[0]) + # this ensure padding has happened + self.assertEqual(mask_label.shape[1:], pixel_values.shape[2:]) + + common() + common(is_instance_map=True) + common(is_instance_map=False, segmentation_type="pil") + common(is_instance_map=True, segmentation_type="pil") + + def test_integration_instance_segmentation(self): + # load 2 images and corresponding annotations from the hub + repo_id = "nielsr/image-segmentation-toy-data" + image1 = Image.open( + hf_hub_download(repo_id=repo_id, filename="instance_segmentation_image_1.png", repo_type="dataset") + ) + image2 = Image.open( + hf_hub_download(repo_id=repo_id, filename="instance_segmentation_image_2.png", repo_type="dataset") + ) + annotation1 = Image.open( + hf_hub_download(repo_id=repo_id, filename="instance_segmentation_annotation_1.png", repo_type="dataset") + ) + annotation2 = Image.open( + hf_hub_download(repo_id=repo_id, filename="instance_segmentation_annotation_2.png", repo_type="dataset") + ) + + # get instance segmentations and instance-to-segmentation mappings + def get_instance_segmentation_and_mapping(annotation): + instance_seg = 
np.array(annotation)[:, :, 1] + class_id_map = np.array(annotation)[:, :, 0] + class_labels = np.unique(class_id_map) + + # create mapping between instance IDs and semantic category IDs + inst2class = {} + for label in class_labels: + instance_ids = np.unique(instance_seg[class_id_map == label]) + inst2class.update({i: label for i in instance_ids}) + + return instance_seg, inst2class + + instance_seg1, inst2class1 = get_instance_segmentation_and_mapping(annotation1) + instance_seg2, inst2class2 = get_instance_segmentation_and_mapping(annotation2) + + # create a image processor + image_processing = Mask2FormerImageProcessor(reduce_labels=True, ignore_index=255, size=(512, 512)) + + # prepare the images and annotations + inputs = image_processing( + [image1, image2], + [instance_seg1, instance_seg2], + instance_id_to_semantic_id=[inst2class1, inst2class2], + return_tensors="pt", + ) + + # verify the pixel values and pixel mask + self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 512)) + self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 512)) + + # verify the class labels + self.assertEqual(len(inputs["class_labels"]), 2) + self.assertTrue(torch.allclose(inputs["class_labels"][0], torch.tensor([30, 55]))) + self.assertTrue(torch.allclose(inputs["class_labels"][1], torch.tensor([4, 4, 23, 55]))) + + # verify the mask labels + self.assertEqual(len(inputs["mask_labels"]), 2) + self.assertEqual(inputs["mask_labels"][0].shape, (2, 512, 512)) + self.assertEqual(inputs["mask_labels"][1].shape, (4, 512, 512)) + self.assertEquals(inputs["mask_labels"][0].sum().item(), 41527.0) + self.assertEquals(inputs["mask_labels"][1].sum().item(), 26259.0) + + def test_integration_semantic_segmentation(self): + # load 2 images and corresponding semantic annotations from the hub + repo_id = "nielsr/image-segmentation-toy-data" + image1 = Image.open( + hf_hub_download(repo_id=repo_id, filename="semantic_segmentation_image_1.png", repo_type="dataset") + ) + image2 = Image.open( + hf_hub_download(repo_id=repo_id, filename="semantic_segmentation_image_2.png", repo_type="dataset") + ) + annotation1 = Image.open( + hf_hub_download(repo_id=repo_id, filename="semantic_segmentation_annotation_1.png", repo_type="dataset") + ) + annotation2 = Image.open( + hf_hub_download(repo_id=repo_id, filename="semantic_segmentation_annotation_2.png", repo_type="dataset") + ) + + # create a image processor + image_processing = Mask2FormerImageProcessor(reduce_labels=True, ignore_index=255, size=(512, 512)) + + # prepare the images and annotations + inputs = image_processing( + [image1, image2], + [annotation1, annotation2], + return_tensors="pt", + ) + + # verify the pixel values and pixel mask + self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 512)) + self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 512)) + + # verify the class labels + self.assertEqual(len(inputs["class_labels"]), 2) + self.assertTrue(torch.allclose(inputs["class_labels"][0], torch.tensor([2, 4, 60]))) + self.assertTrue(torch.allclose(inputs["class_labels"][1], torch.tensor([0, 3, 7, 8, 15, 28, 30, 143]))) + + # verify the mask labels + self.assertEqual(len(inputs["mask_labels"]), 2) + self.assertEqual(inputs["mask_labels"][0].shape, (3, 512, 512)) + self.assertEqual(inputs["mask_labels"][1].shape, (8, 512, 512)) + self.assertEquals(inputs["mask_labels"][0].sum().item(), 170200.0) + self.assertEquals(inputs["mask_labels"][1].sum().item(), 257036.0) + + def test_integration_panoptic_segmentation(self): + # load 2 images and corresponding 
panoptic annotations from the hub + dataset = load_dataset("nielsr/ade20k-panoptic-demo") + image1 = dataset["train"][0]["image"] + image2 = dataset["train"][1]["image"] + segments_info1 = dataset["train"][0]["segments_info"] + segments_info2 = dataset["train"][1]["segments_info"] + annotation1 = dataset["train"][0]["label"] + annotation2 = dataset["train"][1]["label"] + + def rgb_to_id(color): + if isinstance(color, np.ndarray) and len(color.shape) == 3: + if color.dtype == np.uint8: + color = color.astype(np.int32) + return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2] + return int(color[0] + 256 * color[1] + 256 * 256 * color[2]) + + def create_panoptic_map(annotation, segments_info): + annotation = np.array(annotation) + # convert RGB to segment IDs per pixel + # 0 is the "ignore" label, for which we don't need to make binary masks + panoptic_map = rgb_to_id(annotation) + + # create mapping between segment IDs and semantic classes + inst2class = {segment["id"]: segment["category_id"] for segment in segments_info} + + return panoptic_map, inst2class + + panoptic_map1, inst2class1 = create_panoptic_map(annotation1, segments_info1) + panoptic_map2, inst2class2 = create_panoptic_map(annotation2, segments_info2) + + # create a image processor + image_processing = Mask2FormerImageProcessor(ignore_index=0, do_resize=False) + + # prepare the images and annotations + pixel_values_list = [np.moveaxis(np.array(image1), -1, 0), np.moveaxis(np.array(image2), -1, 0)] + inputs = image_processing.encode_inputs( + pixel_values_list, + [panoptic_map1, panoptic_map2], + instance_id_to_semantic_id=[inst2class1, inst2class2], + return_tensors="pt", + ) + + # verify the pixel values and pixel mask + self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 711)) + self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 711)) + + # verify the class labels + self.assertEqual(len(inputs["class_labels"]), 2) + # fmt: off + expected_class_labels = torch.tensor([4, 17, 32, 42, 42, 42, 42, 42, 42, 42, 32, 12, 12, 12, 12, 12, 42, 42, 12, 12, 12, 42, 12, 12, 12, 12, 12, 3, 12, 12, 12, 12, 42, 42, 42, 12, 42, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 5, 12, 12, 12, 12, 12, 12, 12, 0, 43, 43, 43, 96, 43, 104, 43, 31, 125, 31, 125, 138, 87, 125, 149, 138, 125, 87, 87]) # noqa: E231 + # fmt: on + self.assertTrue(torch.allclose(inputs["class_labels"][0], torch.tensor(expected_class_labels))) + # fmt: off + expected_class_labels = torch.tensor([19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 67, 82, 19, 19, 17, 19, 19, 19, 19, 19, 19, 19, 19, 19, 12, 12, 42, 12, 12, 12, 12, 3, 14, 12, 12, 12, 12, 12, 12, 12, 12, 14, 5, 12, 12, 0, 115, 43, 43, 115, 43, 43, 43, 8, 8, 8, 138, 138, 125, 143]) # noqa: E231 + # fmt: on + self.assertTrue(torch.allclose(inputs["class_labels"][1], expected_class_labels)) + + # verify the mask labels + self.assertEqual(len(inputs["mask_labels"]), 2) + self.assertEqual(inputs["mask_labels"][0].shape, (79, 512, 711)) + self.assertEqual(inputs["mask_labels"][1].shape, (61, 512, 711)) + self.assertEquals(inputs["mask_labels"][0].sum().item(), 315193.0) + self.assertEquals(inputs["mask_labels"][1].sum().item(), 350747.0) + + def test_binary_mask_to_rle(self): + fake_binary_mask = np.zeros((20, 50)) + fake_binary_mask[0, 20:] = 1 + fake_binary_mask[1, :15] = 1 + fake_binary_mask[5, :10] = 1 + + rle = binary_mask_to_rle(fake_binary_mask) + self.assertEqual(len(rle), 4) + self.assertEqual(rle[0], 21) + self.assertEqual(rle[1], 45) + + def 
test_post_process_semantic_segmentation(self):
+        feature_extractor = self.image_processing_class(num_labels=self.image_processor_tester.num_classes)
+        outputs = self.image_processor_tester.get_fake_mask2former_outputs()
+
+        segmentation = feature_extractor.post_process_semantic_segmentation(outputs)
+
+        self.assertEqual(len(segmentation), self.image_processor_tester.batch_size)
+        self.assertEqual(segmentation[0].shape, (384, 384))
+
+        target_sizes = [(1, 4) for i in range(self.image_processor_tester.batch_size)]
+        segmentation = feature_extractor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)
+
+        self.assertEqual(segmentation[0].shape, target_sizes[0])
+
+    def test_post_process_instance_segmentation(self):
+        feature_extractor = self.image_processing_class(num_labels=self.image_processor_tester.num_classes)
+        outputs = self.image_processor_tester.get_fake_mask2former_outputs()
+        segmentation = feature_extractor.post_process_instance_segmentation(outputs, threshold=0)
+
+        self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size)
+        for el in segmentation:
+            self.assertTrue("segmentation" in el)
+            self.assertTrue("segments_info" in el)
+            self.assertEqual(type(el["segments_info"]), list)
+            self.assertEqual(el["segmentation"].shape, (384, 384))
+
+        segmentation = feature_extractor.post_process_instance_segmentation(
+            outputs, threshold=0, return_binary_maps=True
+        )
+
+        self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size)
+        for el in segmentation:
+            self.assertTrue("segmentation" in el)
+            self.assertTrue("segments_info" in el)
+            self.assertEqual(type(el["segments_info"]), list)
+            self.assertEqual(len(el["segmentation"].shape), 3)
+            self.assertEqual(el["segmentation"].shape[1:], (384, 384))
+
+    def test_post_process_panoptic_segmentation(self):
+        image_processing = self.image_processing_class(num_labels=self.image_processor_tester.num_classes)
+        outputs = self.image_processor_tester.get_fake_mask2former_outputs()
+        segmentation = image_processing.post_process_panoptic_segmentation(outputs, threshold=0)
+
+        self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size)
+        for el in segmentation:
+            self.assertTrue("segmentation" in el)
+            self.assertTrue("segments_info" in el)
+            self.assertEqual(type(el["segments_info"]), list)
+            self.assertEqual(el["segmentation"].shape, (384, 384))
+
+    def test_post_process_label_fusing(self):
+        image_processor = self.image_processing_class(num_labels=self.image_processor_tester.num_classes)
+        outputs = self.image_processor_tester.get_fake_mask2former_outputs()
+
+        segmentation = image_processor.post_process_panoptic_segmentation(
+            outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0
+        )
+        unfused_segments = [el["segments_info"] for el in segmentation]
+
+        fused_segmentation = image_processor.post_process_panoptic_segmentation(
+            outputs, threshold=0, mask_threshold=0, overlap_mask_area_threshold=0, label_ids_to_fuse={1}
+        )
+        fused_segments = [el["segments_info"] for el in fused_segmentation]
+
+        for el_unfused, el_fused in zip(unfused_segments, fused_segments):
+            if len(el_unfused) == 0:
+                self.assertEqual(len(el_unfused), len(el_fused))
+                continue
+
+            # Get number of segments to be fused
+            fuse_targets = [1 for el in el_unfused if el["label_id"] in {1}]
+            num_to_fuse = 0 if len(fuse_targets) == 0 else sum(fuse_targets) - 1
+            # Expected number of segments after fusing
+            expected_num_segments = max([el["id"] for el in el_unfused]) - num_to_fuse
+            num_segments_fused =
max([el["id"] for el in el_fused]) + self.assertEqual(num_segments_fused, expected_num_segments) diff --git a/tests/models/mask2former/test_modeling_mask2former.py b/tests/models/mask2former/test_modeling_mask2former.py index a3e95088ef50b4..3a8645f64a88f5 100644 --- a/tests/models/mask2former/test_modeling_mask2former.py +++ b/tests/models/mask2former/test_modeling_mask2former.py @@ -34,7 +34,7 @@ from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerModel if is_vision_available(): - from transformers import MaskFormerImageProcessor + from transformers import Mask2FormerImageProcessor if is_vision_available(): from PIL import Image @@ -325,7 +325,7 @@ def model_checkpoints(self): @cached_property def default_feature_extractor(self): - return MaskFormerImageProcessor.from_pretrained(self.model_checkpoints) if is_vision_available() else None + return Mask2FormerImageProcessor.from_pretrained(self.model_checkpoints) if is_vision_available() else None def test_inference_no_head(self): model = Mask2FormerModel.from_pretrained(self.model_checkpoints).to(torch_device) diff --git a/tests/models/maskformer/test_image_processing_maskformer.py b/tests/models/maskformer/test_image_processing_maskformer.py index 2df4b8df925b9b..2d7b710b59da7f 100644 --- a/tests/models/maskformer/test_image_processing_maskformer.py +++ b/tests/models/maskformer/test_image_processing_maskformer.py @@ -576,6 +576,34 @@ def test_post_process_semantic_segmentation(self): self.assertEqual(segmentation[0].shape, target_sizes[0]) + def test_post_process_instance_segmentation(self): + feature_extractor = self.image_processing_class(num_labels=self.image_processor_tester.num_classes) + outputs = self.image_processor_tester.get_fake_maskformer_outputs() + segmentation = feature_extractor.post_process_instance_segmentation(outputs, threshold=0) + + self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size) + for el in segmentation: + self.assertTrue("segmentation" in el) + self.assertTrue("segments_info" in el) + self.assertEqual(type(el["segments_info"]), list) + self.assertEqual( + el["segmentation"].shape, (self.image_processor_tester.height, self.image_processor_tester.width) + ) + + segmentation = feature_extractor.post_process_instance_segmentation( + outputs, threshold=0, return_binary_maps=True + ) + + self.assertTrue(len(segmentation) == self.image_processor_tester.batch_size) + for el in segmentation: + self.assertTrue("segmentation" in el) + self.assertTrue("segments_info" in el) + self.assertEqual(type(el["segments_info"]), list) + self.assertEqual(len(el["segmentation"].shape), 3) + self.assertEqual( + el["segmentation"].shape[1:], (self.image_processor_tester.height, self.image_processor_tester.width) + ) + def test_post_process_panoptic_segmentation(self): image_processing = self.image_processing_class(num_labels=self.image_processor_tester.num_classes) outputs = self.image_processor_tester.get_fake_maskformer_outputs()
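
For readers reviewing this patch without the surrounding docs, here is a minimal end-to-end sketch of the inference path the new tests exercise. This is a sketch rather than part of the patch: the checkpoint name is the one referenced in the diff, the COCO image URL is illustrative only, and both require network access.

```python
import requests
import torch
from PIL import Image

from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor

# Checkpoint referenced in the diff; any Mask2Former instance-segmentation checkpoint should work.
checkpoint = "facebook/mask2former-swin-small-coco-instance"
image_processor = Mask2FormerImageProcessor.from_pretrained(checkpoint)
model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint)

# Illustrative sample image; substitute any RGB image.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Resize predictions back to the original (height, width) and keep confident instances only.
result = image_processor.post_process_instance_segmentation(
    outputs, threshold=0.5, target_sizes=[image.size[::-1]]
)[0]
print(result["segmentation"].shape)  # (height, width) segment-id map, -1 where no instance was kept
print(result["segments_info"])       # one dict per instance: "id", "label_id", "was_fused", "score"
```

Calling `post_process_semantic_segmentation` or `post_process_panoptic_segmentation` on the same `outputs` yields the other two task outputs, as covered by the corresponding tests above.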
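
The rewritten `post_process_instance_segmentation` also gained two mutually exclusive output formats, `return_coco_annotation` and the new `return_binary_maps` (the Mask2Former tests above exercise the latter). A short sketch of the three result shapes, assuming `processor` is an already-instantiated `MaskFormerImageProcessor` and `outputs` a `MaskFormerForInstanceSegmentationOutput` from a forward pass:

```python
# Sketch only: `processor` and `outputs` are assumed to exist already, e.g. from
# MaskFormerImageProcessor.from_pretrained(...) and a MaskFormerForInstanceSegmentation forward pass.

# Default: one (height, width) tensor per image; each pixel holds a segment id, -1 where nothing was kept.
result = processor.post_process_instance_segmentation(outputs, threshold=0.5)[0]
segment_map = result["segmentation"]

# COCO run-length encoding: "segmentation" becomes a list of RLE encodings, one per segment.
rle_result = processor.post_process_instance_segmentation(
    outputs, threshold=0.5, return_coco_annotation=True
)[0]

# New in this patch: a (num_instances, height, width) tensor of stacked binary masks.
binary_result = processor.post_process_instance_segmentation(
    outputs, threshold=0.5, return_binary_maps=True
)[0]

# Setting both flags at once raises a ValueError, as enforced at the top of the method.
```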