-
Notifications
You must be signed in to change notification settings - Fork 77
/
dataset.py
107 lines (87 loc) · 3.8 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import xml.etree.ElementTree as ET
import chainer
import numpy as np
from chainercv.datasets.voc import voc_utils
from chainercv.utils import read_image
from opt import bam_contents_classes
class BaseDetectionDataset(chainer.dataset.DatasetMixin):
    """Base class for PASCAL-VOC-layout detection datasets.

    Expects the standard VOC directory structure under ``root``:
    ``JPEGImages/`` (``.jpg`` images), ``Annotations/`` (per-image XML),
    and ``ImageSets/Main/{subset}.txt`` listing the image ids.

    Args:
        root (str): Path to the dataset root directory.
        subset (str): Name of the image-set split (e.g. ``'train'``,
            ``'val'``); ``ImageSets/Main/{subset}.txt`` must exist.
        use_difficult (bool): If ``False``, objects marked ``difficult``
            in the annotation are skipped.
        return_difficult (bool): If ``True``, :meth:`get_example` also
            returns the per-object ``difficult`` flags.
    """

    def __init__(self, root, subset, use_difficult, return_difficult):
        self.root = root
        self.img_dir = os.path.join(root, 'JPEGImages')
        self.imgset_dir = os.path.join(root, 'ImageSets/Main')
        self.ann_dir = os.path.join(root, 'Annotations')
        id_list_file = os.path.join(
            self.imgset_dir, '{:s}.txt'.format(subset))
        # Context manager closes the file handle promptly; the original
        # relied on the garbage collector to do so.
        with open(id_list_file) as f:
            self.ids = [id_.strip() for id_ in f]
        self.use_difficult = use_difficult
        self.return_difficult = return_difficult
        self.subset = subset
        self.labels = None  # label names fed to the network; set by subclasses
        self.actual_labels = None  # label names for visualization; set by subclasses

    def __len__(self):
        return len(self.ids)

    def get_example(self, i):
        """Returns the i-th example.

        Returns a color image and bounding boxes. The image is in CHW
        format. The returned image is RGB.

        Args:
            i (int): The index of the example.

        Returns:
            tuple of an image and bounding boxes (and, if
            ``return_difficult`` is set, the difficult flags).
        """
        id_ = self.ids[i]
        anno = ET.parse(
            os.path.join(self.ann_dir, id_ + '.xml'))
        bbox = []
        label = []
        difficult = []
        for obj in anno.findall('object'):
            # If not using difficult split, and the object is
            # difficult, skip it.
            if not self.use_difficult and int(obj.find('difficult').text) == 1:
                continue
            bndbox_anno = obj.find('bndbox')
            name = obj.find('name').text.lower().strip()
            label.append(self.labels.index(name))
            # Subtract 1 to make pixel indexes 0-based.
            bbox.append([
                int(bndbox_anno.find(tag).text) - 1
                for tag in ('ymin', 'xmin', 'ymax', 'xmax')])
            difficult.append(int(obj.find('difficult').text))
        if bbox:
            bbox = np.stack(bbox).astype(np.float32)
            label = np.stack(label).astype(np.int32)
            # np.bool was removed in NumPy 1.24; use the builtin bool dtype.
            difficult = np.array(difficult, dtype=bool)
        else:
            # No (kept) objects in this image: return empty, well-shaped
            # arrays. An explicit check is used instead of catching the
            # ValueError from np.stack([]), which could also mask
            # genuine shape errors.
            bbox = np.empty((0, 4), dtype=np.float32)
            label = np.empty((0,), dtype=np.int32)
            difficult = np.empty((0,), dtype=bool)
        # Load the image.
        img_file = os.path.join(self.img_dir, id_ + '.jpg')
        img = read_image(img_file, color=True)
        if self.return_difficult:
            return img, bbox, label, difficult
        return img, bbox, label
class VOCDataset(BaseDetectionDataset):
    """PASCAL VOC detection dataset.

    Both the network labels and the visualization labels are the
    standard VOC bounding-box class names.
    """

    def __init__(self, root, subset, use_difficult=False,
                 return_difficult=False):
        super(VOCDataset, self).__init__(
            root, subset, use_difficult, return_difficult)
        names = voc_utils.voc_bbox_label_names
        self.labels = names
        self.actual_labels = names
class ClipArtDataset(BaseDetectionDataset):
    """Clip-art detection dataset laid out like PASCAL VOC.

    Uses the VOC class names both for the network and for visualization.
    """

    def __init__(self, root, subset, use_difficult=False,
                 return_difficult=False):
        super(ClipArtDataset, self).__init__(
            root, subset, use_difficult, return_difficult)
        names = voc_utils.voc_bbox_label_names
        self.labels = names
        self.actual_labels = names
class BAMDataset(BaseDetectionDataset):
    """BAM detection dataset laid out like PASCAL VOC.

    The network sees VOC class names, while visualization uses the
    BAM contents class names.
    """

    def __init__(self, root, subset, use_difficult=False,
                 return_difficult=False):
        super(BAMDataset, self).__init__(
            root, subset, use_difficult, return_difficult)
        self.labels = voc_utils.voc_bbox_label_names
        self.actual_labels = bam_contents_classes