# -----------------------------------------------------------------------------
# Functions to create dataset and dataloader
# -----------------------------------------------------------------------------
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
from util import data_util

get_dataset_funcs = {
'cifar100_images': data_util.load_cifar100_images,
'imagenet100': data_util.load_imagenet100,
'tiny-imagenet-200': data_util.load_tiny_imagenet_200,
'caltech101': data_util.load_caltech101,
'aircraft': data_util.load_aircraft,
'scars': data_util.load_stanford_cars,
'cub200': data_util.load_CUB_200,
}
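
# Each loader maps a split name to a pair (dataset_train, dataset_val); judging
# from the usage in build_CCD_dataset below, each returned dataset is a
# (paths, labels) pair indexed as dataset[0][i] and dataset[1][i].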


class create_ccd_dataset(Dataset):
    '''
    Input: a dataset dict (see add_dataset_attribute) and the current CCD stage.
    Return: a Dataset that contains only the split data for that stage,
    where stage 0 holds labelled data and stage > 0 holds unlabelled data.
    '''
def __init__(self, dataset, transform, stage) -> None:
super(create_ccd_dataset, self).__init__()
self.dataset = dataset
self.transform = transform
        # flag attached to every sample: 1 = labelled (stage 0), 0 = unlabelled (stage > 0)
        self.batch_labeled_or_not = 1 if stage == 0 else 0

    def __getitem__(self, index):
        # read with OpenCV (BGR) and convert to an RGB PIL image
        batch_data = cv2.imread(self.dataset['paths'][index])
        batch_data = cv2.cvtColor(batch_data, cv2.COLOR_BGR2RGB)
        batch_data = Image.fromarray(batch_data)
batch_label = self.dataset['labels'][index]
batch_unique_index = self.dataset['uq_idx'][index]
if self.transform is not None:
batch_data = self.transform(batch_data)
return batch_data, batch_label, batch_unique_index, np.array([self.batch_labeled_or_not])

    def __len__(self):
return self.dataset['len']
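

# A minimal sketch (hypothetical paths) of the dataset dict that
# create_ccd_dataset expects -- the format produced by add_dataset_attribute below:
#   dataset = {
#       'paths': ['img_0.jpg', 'img_1.jpg'],
#       'labels': [3, 7],
#       'uq_idx': np.array([0, 1]),
#       'len': 2,
#   }
#   sample = create_ccd_dataset(dataset, transform=None, stage=0)[0]
#   # -> (PIL image, 3, 0, np.array([1]))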


class create_ccd_test_dataset(Dataset):
    '''
    Input: a labelled dataset dict and an optional unlabelled dataset dict.
    Return: a Dataset over their concatenation, with labelled samples first.
    '''
def __init__(self, unlabelled_dataset, labelled_dataset, transform) -> None:
super(create_ccd_test_dataset, self).__init__()
self.labelled_dataset = labelled_dataset
self.unlabelled_dataset = unlabelled_dataset
self.transform = transform

    def __getitem__(self, index):
if index < self.labelled_dataset['len']:
batch_data = cv2.imread(self.labelled_dataset['paths'][index])
batch_data = cv2.cvtColor(batch_data, cv2.COLOR_BGR2RGB)
batch_data = Image.fromarray(batch_data)
batch_label = self.labelled_dataset['labels'][index]
batch_unique_index = self.labelled_dataset['uq_idx'][index]
batch_labeled_or_not = 1
else:
index = index - self.labelled_dataset['len']
batch_data = cv2.imread(self.unlabelled_dataset['paths'][index])
batch_data = cv2.cvtColor(batch_data, cv2.COLOR_BGR2RGB)
batch_data = Image.fromarray(batch_data)
batch_label = self.unlabelled_dataset['labels'][index]
batch_unique_index = self.unlabelled_dataset['uq_idx'][index]
batch_labeled_or_not = 0
if self.transform is not None:
batch_data = self.transform(batch_data)
return batch_data, batch_label, batch_unique_index, np.array([batch_labeled_or_not])

    def __len__(self):
        if self.unlabelled_dataset is None:
            return self.labelled_dataset['len']
        else:
            return self.labelled_dataset['len'] + self.unlabelled_dataset['len']


def add_dataset_attribute(dataset):
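    '''
    Unpack a (dataset, index, transform) tuple, shuffle the index list in
    place, and gather the selected image paths and labels into a dataset dict
    with 'paths', 'labels', 'uq_idx', and 'len' keys. Returns the dict, the
    transform, and the largest unique index.
    '''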
dataset, index, transform = dataset
data_path_list = []
label_list = []
np.random.shuffle(index)
for i in index:
data_path_list.append(dataset[0][i])
label_list.append(dataset[1][i])
dataset = {
'paths': data_path_list,
'labels': label_list,
'uq_idx': index,
'len': len(index)
}
return dataset, transform, max(index)


def combined_dataset(dataset_list, use_gt, save_path, eval_on_train=False):
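    '''
    Merge a list of dataset dicts into a single dataset dict. For every stage
    after the first, ground-truth labels are replaced with predicted labels
    read from save_path/pred_labels/ unless use_gt is True.
    '''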
data_path_list, label_list, index = None, None, None
for idx, dataset in enumerate(dataset_list):
data_path_list = dataset['paths'] if data_path_list is None else data_path_list + dataset['paths']
        if idx > 0 and not use_gt:
            # for later stages, read the predicted labels produced by the SS-K-means algorithm
            print("use predicted labels from SS-K-means algorithm")
            if eval_on_train:
                label_file = os.path.join(save_path, 'pred_labels', f'pred_labels_stage_{idx}_train.txt')
            else:
                label_file = os.path.join(save_path, 'pred_labels', f'pred_labels_stage_{idx}.txt')
            with open(label_file, 'r') as f:
                predicted_label_list = [int(x) for x in f.readlines()]
            label_list = label_list + predicted_label_list
else:
# use ground truth labels for dataset from stage 0 or when use_gt is True
label_list = dataset['labels'] if label_list is None else label_list + dataset['labels']
index = dataset['uq_idx'] if index is None else np.concatenate((index, dataset['uq_idx']))
return {
'paths': data_path_list,
'labels': label_list,
'uq_idx': index,
'len': len(index)
}


def create_dataloader(args, dataset_i, stage_i):
    '''
    stage_i == -1: dataloader over dataset_val (evaluation)
    stage_i == -2: dataloader over the test datasets
    stage_i == 0:  dataloader over the labelled training data
    stage_i > 0:   dataloader over dataset_train_i plus the rehearsal dataset
    '''
contrast_dataset_i = None
dataloader_dict, dataset_dict = {}, {}
# create dataloader for evaluation
if stage_i == -1:
dataset_i, transform, _ = add_dataset_attribute(dataset_i)
dataset_i = create_ccd_dataset(dataset_i, transform, stage=0) # set stage=0 to get labelled data
dataloader = torch.utils.data.DataLoader(
dataset_i,
batch_size=args.val_batch_size,
num_workers=args.val_workers,
pin_memory=args.pin_mem,
shuffle=False
)
return dataloader
# create dataloader for testing
if stage_i == -2:
dataset_i_list = []
for dataset_i_j in dataset_i:
dataset_i_j, transform, _ = add_dataset_attribute(dataset_i_j)
dataset_i_list.append(dataset_i_j)
unlabelled_val_dataset = None
if len(dataset_i) > 1:
            if args.eval_version == 'ccd' and not args.train:
                unlabelled_val_dataset, labelled_val_dataset = dataset_i_list[-1], dataset_i_list[:-1]
                dataset_i = combined_dataset(
                    labelled_val_dataset,
                    args.use_gt_for_discovered_data,
                    args.save_path,
                    args.transductive_evaluation
                )
            elif args.eval_version == 'gcd' or args.train:
                unlabelled_val_dataset, dataset_i = dataset_i_list[-1], dataset_i_list[0]
            else:
                raise ValueError('Eval {} is not supported'.format(args.eval_version))
else:
dataset_i = dataset_i_list[0]
dataset_i = create_ccd_test_dataset(unlabelled_val_dataset, dataset_i, transform)
dataloader = torch.utils.data.DataLoader(
dataset_i,
batch_size=args.val_batch_size,
num_workers=args.val_workers,
shuffle=False
)
return dataloader
# create dataloader for training
elif stage_i >= 0:
dataset_i, transform, max_index = add_dataset_attribute(dataset_i)
if not args.use_strong_aug:
contrast_transform = data_util.ContrastiveLearningViewGenerator(base_transform=transform, n_views=args.n_views)
else:
strong_transform = data_util.get_strong_transform(args)
contrast_transform = data_util.StrongWeakView(strong_transform, transform)
        # stage_i == 0 yields the labelled-data loaders; stage_i > 0 yields unlabelled-data loaders
dataset_i_ = create_ccd_dataset(dataset_i, transform, stage=stage_i)
contrast_dataset_i = create_ccd_dataset(dataset_i, contrast_transform, stage=stage_i)
if stage_i == 0:
            # during stage 0, draw samples with replacement; note the weights are
            # uniform here, so this WeightedRandomSampler does not actually
            # rebalance the class distribution
            sample_weights = torch.DoubleTensor([1] * len(dataset_i_))
            sampler = torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(dataset_i_))
else:
# during discovery stage we should not assume any class distribution, i.e., whether it is balanced or not
sampler = None
dataset_dict['default'] = dataset_i_
dataset_dict['contrast'] = contrast_dataset_i
        for name, dataset in dataset_dict.items():
            if dataset is not None:
                dataloader = torch.utils.data.DataLoader(
                    dataset,
                    batch_size=args.batch_size,
                    num_workers=args.workers,
                    pin_memory=args.pin_mem,
                    drop_last=True,
                    # the sampler is only applied to the 'contrast' loader
                    sampler=sampler if name not in ('default', 'eval') else None,
                )
                dataloader_dict[name] = dataloader
return dataloader_dict
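
# A minimal usage sketch, assuming an `args` namespace populated with the
# fields referenced above (batch_size, workers, pin_mem, n_views, ...):
#   loaders = create_dataloader(args, datasets[0], stage_i=0)
#   images, labels, uq_idx, labeled_flag = next(iter(loaders['default']))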


def build_CCD_dataset(args, split):
    '''
    Input: dataset configuration
    Return: datasets for the experimental task
    Note: dataset_train index 0 contains labelled data; indices [1, n_stage] contain mixed data
    '''
dataset_name = args.dataset
number_of_stage = args.n_stage
    ccd_dataset = [None] * number_of_stage
    datasets = []
dataset_l = None
# Get dataset
    try:
        dataset_train, dataset_val = get_dataset_funcs[dataset_name](split)
    except KeyError:
        raise ValueError('Dataset {} not found.'.format(dataset_name))
if split == 'train':
transform = data_util.build_transform('default', args)
dataset = dataset_train
elif split == 'val':
transform = data_util.build_transform('test', args)
return (dataset_val, [x for x in range(len(dataset_val[1]))], transform)
elif split == 'test': # for inductive test dataset
transform = data_util.build_transform('test', args)
dataset = dataset_val
else:
raise ValueError('Split {} is not supported for CCD training.'.format(split))
# Get dataset classes
dataset_target = np.array(dataset[1])
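    # args.ccd_split_ratio appears to be a nested list: row 0 holds the labelled
    # fraction followed by per-stage fractions for known classes, while rows
    # 1..n_stage hold per-stage fractions for classes that first appear at a
    # later stage (inferred from the indexing below).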
for i in range(args.classes):
        class_i = np.where(dataset_target == i)[0]
np.random.shuffle(class_i)
        # If class i belongs to the labelled set, split it across the stages according to args.ccd_split_ratio
if i < args.labelled_data:
class_i_l, class_i_u = np.split(class_i, [int(len(class_i) * args.ccd_split_ratio[0][0])])
dataset_l = class_i_l if dataset_l is None else np.concatenate((dataset_l, class_i_l), axis=0)
for stage_i in range(number_of_stage):
class_i_u, class_i_u_ = np.split(class_i_u, [int(len(class_i_u) * args.ccd_split_ratio[0][stage_i + 1])])
ccd_dataset[stage_i] = class_i_u if ccd_dataset[stage_i] is None else np.concatenate((ccd_dataset[stage_i], class_i_u), axis=0)
class_i_u = class_i_u_
else:
stage_selector = (i % args.labelled_data) // ((args.classes - args.labelled_data) // args.n_stage)
class_i_u = class_i
for stage_i in range(number_of_stage - stage_selector):
class_i_u, class_i_u_ = np.split(class_i_u, [int(len(class_i_u) * args.ccd_split_ratio[stage_selector + 1][stage_i])])
ccd_dataset[stage_i + stage_selector] = class_i_u if ccd_dataset[stage_i + stage_selector] is None else np.concatenate((ccd_dataset[stage_i + stage_selector], class_i_u), axis=0)
class_i_u = class_i_u_
# Create training dataset according to index list
# First index corresponds to labelled training dataset
dataset_l = (dataset, dataset_l, transform)
datasets.append(dataset_l)
    # Remaining indices correspond to the mixed training datasets (known and unknown data)
for dataset_u_i in ccd_dataset:
dataset_u_i = (dataset, dataset_u_i, transform)
datasets.append(dataset_u_i)
return datasets
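

# End-to-end sketch, assuming an `args` namespace populated with the fields
# referenced in this module (dataset, n_stage, classes, labelled_data,
# ccd_split_ratio, batch sizes, ...):
#   datasets = build_CCD_dataset(args, split='train')
#   stage0_loaders = create_dataloader(args, datasets[0], stage_i=0)  # labelled data
#   stage1_loaders = create_dataloader(args, datasets[1], stage_i=1)  # first discovery stage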