-
Notifications
You must be signed in to change notification settings - Fork 23
/
data_provider.py
84 lines (73 loc) · 2.9 KB
/
data_provider.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
import os
import numpy as np
from PIL import Image
class Video_Provider(Dataset):
    """Dataset of noisy frame windows cut from fixed-length image sequences.

    Every non-blank line of the list file names a sequence directory (relative
    to ``base_path``) containing frames ``im1.png`` .. ``im7.png``.  One item is
    a random window of ``frames`` consecutive frames with additive Gaussian
    noise of a randomly drawn level, followed by a constant noise-level map,
    paired with the clean centre frame of the window.
    """

    # Sequence layout assumed by the sampling code (previously inline literals).
    SEQ_LEN = 7          # frames per sequence: im1.png .. im7.png
    FRAME_HEIGHT = 256   # frame height in pixels used for crop sampling
    FRAME_WIDTH = 448    # frame width in pixels used for crop sampling

    def __init__(self, base_path, txt_file, im_size=96, frames=5):
        """
        :param base_path: directory that the entries of ``txt_file`` are relative to
        :param txt_file: text file listing one sequence directory per line
        :param im_size: side of the random square crop, or None to keep full frames
        :param frames: number of consecutive frames per sample (1..SEQ_LEN)
        :raises ValueError: if ``frames`` or ``im_size`` cannot fit in a sequence/frame
        """
        super().__init__()
        if not 1 <= frames <= self.SEQ_LEN:
            raise ValueError(
                'frames must be in [1, {}], got {}'.format(self.SEQ_LEN, frames))
        max_crop = min(self.FRAME_HEIGHT, self.FRAME_WIDTH)
        if im_size is not None and not 1 <= im_size <= max_crop:
            raise ValueError(
                'im_size must be in [1, {}] or None, got {}'.format(max_crop, im_size))
        self.base_path = base_path
        # Close the list file deterministically and drop blank lines, which
        # would otherwise turn into samples pointing at base_path itself.
        with open(txt_file, 'r') as list_file:
            self.txt_file = [line for line in list_file if line.strip()]
        self.im_size = im_size
        self.trans = transforms.ToTensor()
        self.frames = frames

    def _get_file_name(self, index):
        """
        Pick a random window of consecutive frames inside the index-th sequence.

        :param index: number of video in dataset
        :return: list of ``self.frames`` image paths, in temporal order
        :raises IndexError: if ``index`` is past the end of the list; plain
            ``for sample in dataset`` iteration relies on this to terminate
        """
        sequence_dir = os.path.join(self.base_path, self.txt_file[index].strip())
        # np.random.randint excludes its upper bound, so the last valid start
        # (SEQ_LEN - frames + 1) needs a further +1 to be reachable.
        start = np.random.randint(1, self.SEQ_LEN - self.frames + 2)
        return [
            os.path.join(sequence_dir, 'im{}.png'.format(i))
            for i in range(start, start + self.frames)
        ]

    @staticmethod
    def _get_random_sigma():
        """Draw a noise level log-uniformly from [10**-2, 10**-0.5)."""
        r = np.random.rand()
        return 10 ** (1.5*r - 2)

    def _get_crop_h_w(self):
        """Return a random (top, left) corner for an im_size x im_size crop."""
        h = np.random.randint(0, self.FRAME_HEIGHT - self.im_size + 1)
        w = np.random.randint(0, self.FRAME_WIDTH - self.im_size + 1)
        return h, w

    def __getitem__(self, index):
        """
        Build one training sample.

        :param index: number of video in dataset
        :return: ``(noised, gt)`` where ``noised`` stacks the ``frames`` noisy
            frames followed by one constant map holding the noise level
            (``frames + 1`` entries along dim 0) and ``gt`` is the clean
            centre frame of the window
        """
        img_files = self._get_file_name(index)
        # One crop window and one noise level are shared by every frame of the
        # sample so the frames stay spatially aligned.
        crop = self._get_crop_h_w() if self.im_size is not None else None
        sigma = self._get_random_sigma()
        centre = self.frames // 2

        gt = None
        noised = []
        for i, path in enumerate(img_files):
            # The context manager releases the file handle; convert('RGB')
            # guarantees the 3 channels the rest of the pipeline assumes.
            with Image.open(path) as img:
                frame = self.trans(img.convert('RGB'))
            if crop is not None:
                top, left = crop
                frame = frame[:, top:top + self.im_size, left:left + self.im_size]
            if i == centre:
                gt = frame
            noised.append(
                torch.clamp(frame + sigma * torch.randn_like(frame), 0.0, 1.0))
        # Last entry: noise-level map with the same shape as a frame.
        noised.append(sigma * torch.ones_like(gt))
        return torch.stack(noised, dim=0), gt

    def __len__(self):
        """Number of sequences listed in the list file."""
        return len(self.txt_file)
if __name__ == '__main__':
    # Manual smoke test: walk the whole training list once and print the
    # index of every sample that loads successfully.
    provider = Video_Provider(
        base_path='H:/vimeo_septuplet/sequences',
        txt_file='H:/vimeo_septuplet/sep_trainlist.txt',
        im_size=256,
    )
    to_pil = transforms.ToPILImage()
    for sample_idx, (noisy_stack, clean) in enumerate(provider):
        # To inspect a sample visually, dump its frames to disk:
        # for k in range(6):
        #     to_pil(noisy_stack[k, ...]).save('{}_noisy_{}.png'.format(sample_idx, k), quality=100)
        # to_pil(clean).save('{}_gt.png'.format(sample_idx), quality=100)
        print(sample_idx)