-
Notifications
You must be signed in to change notification settings - Fork 0
/
select_dataset.py
48 lines (37 loc) · 1.31 KB
/
select_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import copy
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm
from PIL import Image
class SelectDataset(Dataset):
    """Standalone copy of an image dataset, optionally restricted to a subset.

    The source ``dataset`` must expose ``root``, ``train``, ``target_transform``,
    ``transform``, ``data`` and ``targets`` attributes (the layout used by
    torchvision's CIFAR-style datasets -- assumption, confirm for other
    sources). ``data`` and ``targets`` are copied, so later mutation of the
    source does not leak into this object. Items are returned as
    ``(image, one_hot_label)``.

    Args:
        dataset: Source dataset to copy from.
        pre_idx: Optional index (list / array / slice) selecting the samples
            to keep; ``None`` keeps every sample.
        transform: Transform applied to each PIL image; ``None`` reuses
            ``dataset.transform``.
        num_classes: Length of the one-hot label vector. Defaults to 10, the
            value that used to be hard-coded.

    Attributes:
        height, width, channels: Per-sample image shape. Samples are fed to
            ``PIL.Image.fromarray``, which expects ``(H, W)`` or ``(H, W, C)``
            arrays, so ``data`` is read as ``(N, H, W[, C])``; 2-D samples
            report a single channel.
    """

    def __init__(self, dataset, pre_idx=None, transform=None, num_classes=10):
        self.root = dataset.root
        self.train = dataset.train
        # NOTE(review): kept for API parity with the source dataset, but it is
        # never applied in __getitem__ (labels are one-hot encoded directly).
        self.target_transform = dataset.target_transform
        self.num_classes = num_classes

        if pre_idx is None:
            self.data = copy.deepcopy(dataset.data)
            # np.array copies by default, so the source targets are not aliased.
            self.targets = np.array(dataset.targets)
        else:
            # deepcopy matters when pre_idx is a slice: slicing returns a view
            # that would otherwise share memory with the source data.
            self.data = copy.deepcopy(dataset.data[pre_idx])
            self.targets = np.array(dataset.targets)[pre_idx]

        self.transform = dataset.transform if transform is None else transform
        self.height, self.width, self.channels = self._hwc_shape()

    def __getitem__(self, item):
        """Return ``(image, label)`` where ``label`` is a one-hot float tensor."""
        img = Image.fromarray(self.data[item])
        if self.transform is not None:
            img = self.transform(img)
        label = torch.zeros(self.num_classes)
        label[int(self.targets[item])] = 1.0
        return img, label

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __shape_info__(self):
        """Return the per-sample shape of ``data`` (everything after axis 0)."""
        return self.data.shape[1:]

    def _hwc_shape(self):
        """Return ``(height, width, channels)`` for one sample of ``data``."""
        sample_shape = tuple(self.__shape_info__())
        if len(sample_shape) == 2:
            # Grayscale (N, H, W) storage: one implicit channel.
            return sample_shape[0], sample_shape[1], 1
        if len(sample_shape) == 3:
            return sample_shape
        raise ValueError(
            "expected data of shape (N, H, W) or (N, H, W, C), "
            f"got per-sample shape {sample_shape}"
        )