import torch
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler
from PIL import Image
import numpy as np


def default_loader(path):
    # Load an image as RGB; if reading fails, log the offending path and
    # fall back to a blank 224x224 white image so training can continue.
    try:
        img = Image.open(path).convert('RGB')
    except Exception:
        with open('read_error.txt', 'a') as fid:
            fid.write(path + '\n')
        return Image.new('RGB', (224, 224), 'white')
    return img
class RandomDataset(Dataset):
    """Validation dataset: reads "image_path label" pairs from val.txt."""

    def __init__(self, transform=None, dataloader=default_loader):
        self.transform = transform
        self.dataloader = dataloader
        with open('/home/pqzhuang/data/CUB/CUB_200_2011/val.txt', 'r') as fid:
            self.imglist = fid.readlines()

    def __getitem__(self, index):
        image_name, label = self.imglist[index].strip().split()
        image_path = image_name
        img = self.dataloader(image_path)
        if self.transform is not None:
            img = self.transform(img)
        label = torch.LongTensor([int(label)])
        return [img, label]

    def __len__(self):
        return len(self.imglist)
class BatchDataset(Dataset):
    """Training dataset: reads "image_path label" pairs from train.txt and
    keeps all labels as a LongTensor so a sampler can group indices by class."""

    def __init__(self, transform=None, dataloader=default_loader):
        self.transform = transform
        self.dataloader = dataloader
        with open('/home/pqzhuang/data/CUB/CUB_200_2011/train.txt', 'r') as fid:
            self.imglist = fid.readlines()

        labels = []
        for line in self.imglist:
            image_path, label = line.strip().split()
            labels.append(int(label))
        self.labels = torch.LongTensor(np.array(labels))

    def __getitem__(self, index):
        image_name, label = self.imglist[index].strip().split()
        image_path = image_name
        img = self.dataloader(image_path)
        if self.transform is not None:
            img = self.transform(img)
        label = torch.LongTensor([int(label)])
        return [img, label]

    def __len__(self):
        return len(self.imglist)
class BalancedBatchSampler(BatchSampler):
    """Yields class-balanced batches of n_classes * n_samples indices:
    n_classes classes are drawn at random and n_samples examples are taken
    from each, which is useful for triplet/metric-learning losses."""

    def __init__(self, dataset, n_classes, n_samples):
        self.labels = dataset.labels
        self.labels_set = list(set(self.labels.numpy()))
        self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0]
                                 for label in self.labels_set}
        for l in self.labels_set:
            np.random.shuffle(self.label_to_indices[l])
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.dataset = dataset
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        self.count = 0
        while self.count + self.batch_size < len(self.dataset):
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                # Take the next n_samples unused indices for this class.
                start = self.used_label_indices_count[class_]
                indices.extend(self.label_to_indices[class_][start:start + self.n_samples])
                self.used_label_indices_count[class_] += self.n_samples
                # Reshuffle and restart once a class runs out of fresh indices.
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.n_classes * self.n_samples

    def __len__(self):
        return len(self.dataset) // self.batch_size
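

# --- Usage sketch (not part of the original file) ---
# A minimal sketch of how these classes are typically wired into a torch
# DataLoader: BalancedBatchSampler is passed via batch_sampler so each batch
# holds n_samples images from each of n_classes classes. The transform,
# n_classes/n_samples values, and worker count below are illustrative
# assumptions, not settings taken from this repository.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from torchvision import transforms

    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    train_set = BatchDataset(transform=train_transform)
    balanced_sampler = BalancedBatchSampler(train_set, n_classes=8, n_samples=4)
    # batch_sampler yields complete index lists, so batch_size/shuffle/drop_last
    # must be left unset on the DataLoader.
    train_loader = DataLoader(train_set, batch_sampler=balanced_sampler, num_workers=4)

    for images, labels in train_loader:
        # images: (32, 3, 224, 224); labels: (32, 1), 8 classes x 4 samples each.
        break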