-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
76 lines (59 loc) · 2.34 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from torch.utils.data import DataLoader, Dataset
import numpy as np
import io
from PIL import Image
import pandas as pd
import os.path as osp
class NormalDataset(Dataset):
    """Image dataset described by a space-separated meta file.

    Every row of ``meta_file`` is read (without a header row) as
    ``<image path relative to root_dir> <label>``. Indexing the dataset
    yields ``(image, label)``.

    Args:
        root_dir: Directory the image paths are relative to. Unless it starts
            with ``/data`` it is joined onto the current user's home directory.
        meta_file: Path of the space-separated meta file.
        is_train: When true and ``args`` is given, the rows are resampled with
            replacement to ``args.max_iter * args.batch_size`` entries.
        args: Object exposing ``max_iter`` and ``batch_size``. Only read when
            ``is_train`` is true; when it is ``None`` no resampling happens.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, root_dir, meta_file, is_train=True, args=None, transform=None):
        self.root_dir = root_dir
        if not self.root_dir.startswith("/data"):
            self.root_dir = osp.join(osp.expanduser('~'), self.root_dir)
        self.transform = transform

        print("building dataset from %s"%meta_file)
        self.metas = pd.read_csv(meta_file, sep=" ", header=None)
        print("read meta done")

        # Guard on ``args`` so the default call (is_train=True, args=None)
        # does not die with an AttributeError on ``args.max_iter``.
        if is_train and args is not None:
            total_num = args.max_iter * args.batch_size
            self.metas = self.metas.sample(total_num, replace=True)
            # Renumber rows 0..total_num-1; the pre-sampling row number is
            # kept in an extra column named "index".
            self.metas = self.metas.reset_index()
        self.num = len(self.metas)

    def __len__(self):
        """Return the number of (possibly resampled) rows."""
        return self.num

    def __getitem__(self, idx):
        """Load row ``idx`` and return ``(image, label)``.

        The image is opened with PIL, converted to RGB and, if a transform
        was supplied, passed through it.
        """
        # ``DataFrame.ix`` no longer exists in pandas >= 1.0. ``.loc`` keeps
        # the label-based lookup: row label ``idx``, column labels 0 and 1.
        filename = osp.join(self.root_dir, self.metas.loc[idx, 0])
        label = self.metas.loc[idx, 1]

        # ``convert`` returns a new, fully loaded image, so the underlying
        # file can be closed as soon as the conversion is done.
        with Image.open(filename) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        return img, label
class TeacherDataset(Dataset):
    """Image dataset that yields two transformed views of each image.

    Every row of ``meta_file`` is read (without a header row) as
    ``<image path relative to root_dir> <class>``. Indexing the dataset
    yields ``(img1, img2, cls)``.

    Args:
        root_dir: Directory the image paths are relative to. Unless it starts
            with ``/data`` it is joined onto the current user's home directory.
        meta_file: Path of the space-separated meta file.
        transform: Optional callable that takes the RGB PIL image and returns
            a pair ``(img1, img2)``.
        args: Object exposing ``max_iter`` and ``batch_size``; when given
            together with a true ``is_train`` the rows are resampled with
            replacement to ``args.max_iter * args.batch_size`` entries.
        is_train: Enables the resampling described above.
    """

    def __init__(self, root_dir, meta_file, transform=None, args=None, is_train=True):
        self.root_dir = root_dir
        if not self.root_dir.startswith("/data"):
            self.root_dir = osp.join(osp.expanduser('~'), self.root_dir)
        self.transform = transform

        # Announce before reading so the message order reflects what happens.
        print("building dataset from %s"%meta_file)
        metas = pd.read_csv(meta_file, sep=" ", header=None)
        if args is not None and is_train:
            metas = metas.sample(args.max_iter * args.batch_size, replace=True)
            # Renumber rows 0..N-1; the pre-sampling row number is kept in an
            # extra column named "index".
            metas = metas.reset_index()
        self.metas = metas
        self.num = len(self.metas)
        print("read meta done")

    def __len__(self):
        """Return the number of (possibly resampled) rows."""
        return self.num

    def __getitem__(self, idx):
        """Load row ``idx`` and return ``(img1, img2, cls)``.

        Without a transform both views are the same untransformed RGB image.
        """
        # ``DataFrame.ix`` no longer exists in pandas >= 1.0. ``.loc`` keeps
        # the label-based lookup: row label ``idx``, column labels 0 and 1.
        filename = self.root_dir + '/' + self.metas.loc[idx, 0]
        cls = self.metas.loc[idx, 1]

        # ``convert`` returns a new, fully loaded image, so the underlying
        # file can be closed as soon as the conversion is done.
        with Image.open(filename) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            img1, img2 = self.transform(img)
        else:
            # Previously this path hit an UnboundLocalError on img1/img2.
            img1 = img2 = img
        return img1, img2, cls