import glob
import os
import scipy.io as sio
from torch.utils.data import Dataset  # Dataset base class from PyTorch
from PIL import Image, ImageChops  # PIL (Pillow) for loading and manipulating images
import torchvision.transforms as transforms  # torchvision transforms for computer vision applications
import numpy as np
import torch
# import sys
# https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
def get_clothCoParse_class_names():
    # names ordered according to label id: 0 for background, 58 for wedges
    ClothCoParse_class_names = ['background', 'accessories', 'bag', 'belt', 'blazer',
                                'blouse', 'bodysuit', 'boots', 'bra', 'bracelet', 'cape', 'cardigan',
                                'clogs', 'coat', 'dress', 'earrings', 'flats', 'glasses', 'gloves', 'hair',
                                'hat', 'heels', 'hoodie', 'intimate', 'jacket', 'jeans', 'jumper', 'leggings',
                                'loafers', 'necklace', 'panties', 'pants', 'pumps', 'purse', 'ring', 'romper',
                                'sandals', 'scarf', 'shirt', 'shoes', 'shorts', 'skin', 'skirt', 'sneakers',
                                'socks', 'stockings', 'suit', 'sunglasses', 'sweater', 'sweatshirt', 'swimwear',
                                't-shirt', 'tie', 'tights', 'top', 'vest', 'wallet', 'watch', 'wedges']
    return ClothCoParse_class_names
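
# Quick lookup example (illustrative only), assuming the label ordering above matches
# the ids stored in the ClothCoParse "groundtruth" masks:
#   class_names = get_clothCoParse_class_names()
#   class_names[14]               # -> 'dress'
#   class_names.index('t-shirt')  # -> 51
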
class ImageDataset(Dataset):
    def __init__(self, root, transforms_=None, transforms_target=None,
                 mode="train", person_detection=False,
                 HPC_run=False, remove_background=True,
                 ):
        self.remove_background = remove_background
        self.person_detection = person_detection
        if transforms_ is not None:
            self.transforms = transforms.Compose(transforms_)  # image transform
        else:
            self.transforms = None
        if transforms_target is not None:
            self.transforms_target = transforms.Compose(transforms_target)  # target transform
        else:
            self.transforms_target = None
        if HPC_run:
            root = '/home/malrawi/MyPrograms/Data/ClothCoParse'
        self.files_A = sorted(glob.glob(os.path.join(root, "%s/A" % mode) + "/*.*"))  # source image file names
        self.files_B = sorted(glob.glob(os.path.join(root, "%s/B" % mode) + "/*.*"))  # target annotation (.mat) file names
    def number_of_classes(self, opt):
        if opt.person_detection:
            return 2  # background + person
        else:
            return len(get_clothCoParse_class_names())  # 59 classes, background included
    def __getitem__(self, index):
        annot = sio.loadmat(self.files_B[index % len(self.files_B)])
        mask = annot["groundtruth"]  # pixel-level class-id annotation
        image_A = Image.open(self.files_A[index % len(self.files_A)])  # source image paired with the mask; index selects which file to read
        remove_background = self.remove_background  # local copy so the dataset state is not mutated per item
        if remove_background or self.person_detection:
            mm = np.uint8(mask > 0)  # threshold the mask: 1 for any labelled pixel, 0 for background
            if self.person_detection:
                mask = mm  # binary (person vs background) mask; Image.fromarray(255*mask).show()
                remove_background = False  # background should not be removed in person detection
            if remove_background:
                image_A = ImageChops.multiply(image_A, Image.fromarray(255 * mm).convert('RGB'))
        # instances are encoded as different class ids (colors)
        obj_ids = np.unique(mask)[1:]  # first id is the background, so remove it
        masks = mask == obj_ids[:, None, None]  # split the id-encoded mask into a set of binary masks
        # get bounding box coordinates for each mask
        num_objs = len(obj_ids)
        boxes = []
        for i in range(num_objs):
            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            boxes.append([xmin, ymin, xmax, ymax])
        # convert everything into torch.Tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        target = {}
        target["boxes"] = boxes
        target["labels"] = torch.as_tensor(obj_ids, dtype=torch.int64)  # class id per instance; corrected by Rawi
        target["masks"] = torch.as_tensor(masks, dtype=torch.uint8)
        target["image_id"] = torch.tensor([index])
        target["area"] = area
        target["iscrowd"] = torch.zeros((num_objs,), dtype=torch.int64)  # assume no instance is a crowd
        if self.transforms is not None:
            img = self.transforms(image_A)
        else:
            img = image_A
        if self.transforms_target is not None:
            target = self.transforms_target(target)
        return img, target
    def __len__(self):
        # the dataset length is the number of annotation files; the number of source
        # images might not equal the number of targets if the data is unaligned
        return len(self.files_B)
# transforms_ = [
#     transforms.Resize((300, 300), Image.BICUBIC),
#     transforms.ToTensor(),
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
# ]
# x_data = ImageDataset("../data/%s" % "ClothCoParse",
#                       transforms_=None,  # or pass transforms_ from above
#                       mode="train",
#                       HPC_run=False,
#                       )
# for i in range(len(x_data)):
#     print(i)
#     z = x_data[i]  # the i-th image and its pixel-level annotation (target dict)
# x_data[0][1]  # the target dict of the first pair
# # plt.imshow(anno.convert('L'), cmap=plt.cm.get_cmap("gist_stern"), vmin=0, vmax=255)
#
# Data-cleaning note: pairs with no annotated objects will cause an error downstream;
# inside __getitem__ they can be detected with something like:
# if num_objs == 0:
#     print('############ 0 objects ################# ')
#     print(self.files_B[index % len(self.files_B)])
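

# A minimal usage sketch, not part of the original pipeline: feeding ImageDataset to a
# torch DataLoader the way torchvision detection models expect, i.e. with a collate_fn
# that keeps images and target dicts grouped as tuples instead of stacked tensors. The
# data root and transform list below are assumptions; adjust them to your local
# ClothCoParse layout (mode/A for the images, mode/B for the .mat masks).
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = ImageDataset("../data/ClothCoParse",               # assumed local data root
                           transforms_=[transforms.ToTensor()],  # assumed minimal image transform
                           mode="train")
    loader = DataLoader(dataset, batch_size=2, shuffle=True,
                        collate_fn=lambda batch: tuple(zip(*batch)))
    images, targets = next(iter(loader))
    print(len(images), targets[0]["boxes"].shape, targets[0]["labels"])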