import os
from typing import Tuple

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

try:
    from pycocotools.coco import COCO
    from pycocotools import mask as pycocotools_mask
except ModuleNotFoundError as ex:
    print("[WARNING] " + str(ex))

from super_gradients.common.object_names import Datasets
from super_gradients.common.registry.registry import register_dataset
from super_gradients.training.datasets.datasets_conf import COCO_DEFAULT_CLASSES_TUPLES_LIST
from super_gradients.training.datasets.segmentation_datasets.segmentation_dataset import SegmentationDataSet


class EmptyCoCoClassesSelectionException(Exception):
    pass


@register_dataset(Datasets.COCO_SEGMENTATION_DATASET)
class CoCoSegmentationDataSet(SegmentationDataSet):
    """
    Segmentation Data Set Class for the COCO 2017 Segmentation Data Set.

    To use this Dataset you need to:

        - Download the COCO dataset:
            annotations: http://images.cocodataset.org/annotations/annotations_trainval2017.zip
            train2017: http://images.cocodataset.org/zips/train2017.zip
            val2017: http://images.cocodataset.org/zips/val2017.zip

        - Unzip and organize it as below:
            coco
            ├── annotations
            │   ├─ instances_train2017.json
            │   ├─ instances_val2017.json
            │   └─ ...
            └── images
                ├── train2017
                │   ├─ 000000000001.jpg
                │   └─ ...
                └── val2017
                    └─ ...

        - Instantiate the dataset:
            >> train_set = CoCoSegmentationDataSet(data_dir='.../coco', subdir='images/train2017', json_file='instances_train2017.json', ...)
            >> valid_set = CoCoSegmentationDataSet(data_dir='.../coco', subdir='images/val2017', json_file='instances_val2017.json', ...)
    """

    def __init__(self, root_dir: str, dataset_classes_inclusion_tuples_list: list = None, *args, **kwargs):
        # THERE ARE 91 CLASSES, INCLUDING BACKGROUND, BUT WE ENABLE THE USAGE OF SUB-CLASSES TO PARTIALLY USE THE DATA
        self.dataset_classes_inclusion_tuples_list = dataset_classes_inclusion_tuples_list or COCO_DEFAULT_CLASSES_TUPLES_LIST

        self.root_dir = root_dir
        super().__init__(root_dir, *args, **kwargs)

        _, class_names = zip(*self.dataset_classes_inclusion_tuples_list)
        self.classes = class_names

    def _generate_samples_and_targets(self):
        """
        Load the COCO annotations, resolve the relevant image ids (creating and caching the
        sub-dataset on first run) and build the (image_path, mask_metadata) tuples list.
        """
        # FIRST OF ALL LOAD ALL OF THE ANNOTATIONS, AND CREATE THE PATH FOR THE PRE-PROCESSED MASKS
        self.annotations_file_path = os.path.join(self.root, self.targets_sub_directory, self.list_file_path)
        self.coco = COCO(self.annotations_file_path)

        # USE SUB-CLASSES OF THE ENTIRE COCO DATA SET, INSTEAD OF ALL OF THE DATA -> HIGHLY RELEVANT FOR TRANSFER LEARNING
        sub_dataset_image_ids_file_path = self.annotations_file_path.replace(".json", ".pth")
        if os.path.exists(sub_dataset_image_ids_file_path):
            self.relevant_image_ids = torch.load(sub_dataset_image_ids_file_path)
        else:
            self.relevant_image_ids = self._sub_dataset_creation(sub_dataset_image_ids_file_path)

        for relevant_image_id in self.relevant_image_ids:
            img_metadata = self.coco.loadImgs(relevant_image_id)[0]
            image_path = os.path.join(self.root, self.samples_sub_directory, img_metadata["file_name"])
            mask_metadata_tuple = (relevant_image_id, img_metadata["height"], img_metadata["width"])
            self.samples_targets_tuples_list.append((image_path, mask_metadata_tuple))

        super(CoCoSegmentationDataSet, self)._generate_samples_and_targets()

    def target_loader(self, mask_metadata_tuple) -> Image.Image:
        """
        Create the segmentation mask image for a single sample from its COCO annotations.

        :param mask_metadata_tuple: A tuple of (coco_image_id, original_image_height, original_image_width)
        :return: The mask image created from the annotations array
        """
        coco_image_id, original_image_h, original_image_w = mask_metadata_tuple
        coco_annotations = self.coco.loadAnns(self.coco.getAnnIds(imgIds=coco_image_id))

        mask = self._generate_coco_segmentation_mask(coco_annotations, original_image_h, original_image_w)
        return Image.fromarray(mask)

    def _generate_coco_segmentation_mask(self, target_coco_annotations, h: int, w: int) -> np.ndarray:
        """
        Extract a segmentation mask from the COCO annotations of a single image.

        :param target_coco_annotations: The COCO annotation instances of a single image
        :param h: The original image height
        :param w: The original image width
        :return: A (h, w) uint8 mask holding the sub-dataset class index of every pixel
        """
        if not self.dataset_classes_inclusion_tuples_list:
            # NO CLASSES WERE SELECTED FROM COCO'S 91 CLASSES - ERROR
            raise EmptyCoCoClassesSelectionException

        sub_classes_category_ids, _ = map(list, zip(*self.dataset_classes_inclusion_tuples_list))

        mask = np.zeros((h, w), dtype=np.uint8)
        for instance in target_coco_annotations:
            # FILTER OUT ALL OF THE MASKS OF INSTANCES THAT ARE NOT IN THE SUB-DATASET CLASSES
            class_category = instance["category_id"]
            if class_category not in sub_classes_category_ids:
                continue
            class_index = sub_classes_category_ids.index(class_category)

            rle = pycocotools_mask.frPyObjects(instance["segmentation"], h, w)
            coco_segmentation_mask = pycocotools_mask.decode(rle)

            if len(coco_segmentation_mask.shape) < 3:
                # A SINGLE-PART MASK - WRITE THE CLASS INDEX ONLY WHERE NO EARLIER INSTANCE WAS WRITTEN
                mask[:, :] += (mask == 0) * (coco_segmentation_mask * class_index)
            else:
                # A MULTI-PART MASK - COLLAPSE THE PARTS ALONG THE LAST AXIS BEFORE WRITING THE CLASS INDEX
                mask[:, :] += (mask == 0) * (((np.sum(coco_segmentation_mask, axis=2)) > 0) * class_index).astype(np.uint8)
        return mask
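
    # A tiny worked example of the `(mask == 0) * ...` composition above (illustrative values):
    # given mask = [[0, 2], [0, 0]] and a new instance of class_index 3 covering every pixel,
    # only the zero pixels are written, yielding [[3, 2], [3, 3]] - earlier instances win, so
    # overlapping objects never overwrite each other.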

    def _sub_dataset_creation(self, sub_dataset_image_ids_file_path: str) -> list:
        """
        Create the sub-dataset image ids list by generating the segmentation mask of every image
        (through self._generate_coco_segmentation_mask, which keeps only the annotations of the
        sub-classes selected when instantiating the class) and keeping the ids of the images with
        enough remaining pixel data.

        :param sub_dataset_image_ids_file_path: The path to save the sub-dataset image ids in for future loading
        :return: All of the image ids with enough pixel data after the sub-classes filtration
        """
        print("Creating the sub-dataset. This will take a while, but it only runs once and caches the results")
        all_coco_image_ids = list(self.coco.imgs.keys())
        sub_dataset_image_ids = []

        with tqdm(all_coco_image_ids, desc="Generating sub-dataset image ids") as tbar:
            for i, img_id in enumerate(tbar):
                coco_target_annotations = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
                img_metadata = self.coco.loadImgs(img_id)[0]
                mask = self._generate_coco_segmentation_mask(coco_target_annotations, img_metadata["height"], img_metadata["width"])

                # MAKE SURE THERE IS ENOUGH INPUT IN THE IMAGE (MORE THAN 1K PIXELS) AFTER SUB-CLASSES FILTRATION
                if (mask > 0).sum() > 1000:
                    sub_dataset_image_ids.append(img_id)

                tbar.set_description(
                    "Processed images: {}/{}, generated {} qualified images".format(i + 1, len(all_coco_image_ids), len(sub_dataset_image_ids))
                )

        print("Number of images in sub-dataset: ", len(sub_dataset_image_ids))
        torch.save(sub_dataset_image_ids, sub_dataset_image_ids_file_path)
        return sub_dataset_image_ids

    @property
    def _original_dataset_image_shape(self) -> Tuple[int, int]:
        """
        Return the image shape when the data set contains images of uniform shape.
        """
        return 512, 512
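

if __name__ == "__main__":
    # A minimal, hedged usage sketch. The path is hypothetical, and the keyword names below are
    # assumptions inferred from the attributes this class reads (self.samples_sub_directory,
    # self.targets_sub_directory, self.list_file_path); check the SegmentationDataSet base class
    # for the exact constructor signature before running.
    train_set = CoCoSegmentationDataSet(
        root_dir="/data/coco",  # hypothetical local path
        samples_sub_directory="images/train2017",  # assumed kwarg name
        targets_sub_directory="annotations",  # assumed kwarg name
        list_file="instances_train2017.json",  # assumed kwarg name
    )
    print("samples:", len(train_set), "first classes:", train_set.classes[:5])
    image, mask = train_set[0]  # standard torch Dataset indexing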