-
Notifications
You must be signed in to change notification settings - Fork 1
/
datahandler.py
73 lines (65 loc) · 3.15 KB
/
datahandler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import numpy as np
from torch.utils.data.sampler import BatchSampler
class BalancedBatchSampler(BatchSampler):
    """
    BatchSampler that yields class-balanced batches of size
    n_classes * n_samples: each batch picks n_classes distinct classes at
    random and takes n_samples example indices from each.
    """
    def __init__(self, labels, n_classes, n_samples):
        """
        labels: per-example class labels (torch tensor, list, or ndarray).
        n_classes: number of distinct classes drawn per batch.
        n_samples: number of examples taken per class per batch.
        """
        if(torch.is_tensor(labels)):
            labels = labels.numpy()
        if(isinstance(labels, list)):
            labels = np.array(labels)
        self.labels_set = list(set(labels))
        # Map each label to the array of indices of its examples, shuffled
        # so batches draw the examples in random order.
        self.label_to_indices = {label: np.where(labels == label)[0]
                                 for label in self.labels_set}
        for l in self.labels_set:
            np.random.shuffle(self.label_to_indices[l])
        # Per-label cursor into label_to_indices; rewound (with a reshuffle)
        # when the label cannot fill another draw of n_samples.
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.n_dataset = len(labels)
        self.batch_size = self.n_samples * self.n_classes
    def __iter__(self):
        self.count = 0
        # '<=' so exactly len(self) batches are produced. The previous '<'
        # silently dropped the last batch whenever batch_size divided
        # n_dataset evenly, making __iter__ disagree with __len__.
        while self.count + self.batch_size <= self.n_dataset:
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                start = self.used_label_indices_count[class_]
                indices.extend(self.label_to_indices[class_][start:start + self.n_samples])
                self.used_label_indices_count[class_] += self.n_samples
                # Not enough unused examples left for another draw of this
                # class: reshuffle its indices and start over from 0.
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.n_classes * self.n_samples
    def __len__(self):
        # Number of full balanced batches available in one pass.
        return self.n_dataset // self.batch_size
class BalancedDataLoader(torch.utils.data.DataLoader):
    """
    DataLoader that yields class-balanced batches when the dataset exposes
    labels (via a 'y' or 'targets' attribute, possibly through a Subset);
    otherwise it falls back to a plain DataLoader with default batching.
    """
    def __init__(self, dataset, batch_size=1, num_workers=0, collate_fn=None,
                 pin_memory=False, worker_init_fn=None):
        targets = BalancedDataLoader.getTargets(dataset)
        if(targets is None):
            # No labels found: behave like an ordinary DataLoader.
            super().__init__(dataset, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, worker_init_fn=worker_init_fn)
        else:
            if(batch_size > len(targets)):
                batch_size = len(targets)
            if(torch.is_tensor(targets)):
                targets = targets.cpu().numpy()
            nclasses = len(set(targets))
            # Guard against batch_size < nclasses: batch_size // nclasses
            # would be 0, giving the sampler batch_size 0 and an iterator
            # that loops forever yielding empty batches.
            n_samples = max(1, batch_size // nclasses)
            sampler = BalancedBatchSampler(targets, nclasses, n_samples)
            super().__init__(dataset, num_workers=num_workers, batch_sampler=sampler,
                             collate_fn=collate_fn, pin_memory=pin_memory, worker_init_fn=worker_init_fn)
    @staticmethod
    def getTargets(dataset):
        """
        Best-effort label extraction: returns dataset.y, dataset.targets, or
        the labels of a Subset's underlying dataset restricted to its
        indices. Returns None when no labels can be found.
        """
        if(hasattr(dataset, 'y')):
            return dataset.y
        elif(hasattr(dataset, 'targets')):
            return dataset.targets
        if(isinstance(dataset, torch.utils.data.Subset)):
            targets = BalancedDataLoader.getTargets(dataset.dataset)
            # Underlying dataset may itself be unlabeled.
            if(targets is None):
                return None
            # Plain Python lists do not support fancy indexing with an
            # index list; promote to ndarray first.
            if(isinstance(targets, list)):
                targets = np.asarray(targets)
            return targets[dataset.indices]
        return None