-
Notifications
You must be signed in to change notification settings - Fork 15
/
zurich_raw2rgb_dataset.py
47 lines (35 loc) · 1.38 KB
/
zurich_raw2rgb_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import os
import cv2
class ZurichRAW2RGB(torch.utils.data.Dataset):
    """ Canon RGB images from the "Zurich RAW to RGB mapping" dataset. You can download the full
    dataset (22 GB) from http://people.ee.ethz.ch/~ihnatova/pynet.html#dataset. Alternatively, you can only download the
    Canon RGB images (5.5 GB) from https://data.vision.ee.ethz.ch/bhatg/zurich-raw-to-rgb.zip

    Images are read from ``<root>/<split>/canon/<index>.jpg``, with indices running
    contiguously from 0. Each item is returned exactly as loaded by ``cv2.imread`` with its
    default flags, i.e. a uint8 numpy array with channels in **BGR** order (OpenCV's
    convention), not RGB. Convert it yourself if RGB order is required.

    Args:
        root: Dataset root directory containing the ``train`` / ``test`` sub-folders.
        split: Either ``'train'`` or ``'test'``.

    Raises:
        ValueError: If ``split`` is not a known split.
    """

    # Number of images per split; file names are '0.jpg' ... '<n-1>.jpg'.
    _NUM_IMAGES = {'train': 46839, 'test': 1204}

    def __init__(self, root, split='train'):
        super().__init__()
        if split not in self._NUM_IMAGES:
            raise ValueError('Unknown split {}'.format(split))
        self.img_pth = os.path.join(root, split, 'canon')
        self.image_list = self._get_image_list(split)

    def _get_image_list(self, split):
        """Return the list of image file names for ``split``.

        Raises:
            ValueError: If ``split`` is not a known split.
        """
        try:
            num_images = self._NUM_IMAGES[split]
        except KeyError:
            raise ValueError('Unknown split {}'.format(split)) from None
        return ['{:d}.jpg'.format(i) for i in range(num_images)]

    def _get_image(self, im_id):
        """Load image number ``im_id`` with OpenCV (BGR channel order).

        Raises:
            IOError: If OpenCV cannot read the file (e.g. missing or corrupt).
        """
        path = os.path.join(self.img_pth, self.image_list[im_id])
        img = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising; fail loudly here
        # rather than handing None to downstream transforms.
        if img is None:
            raise IOError('Could not read image {}'.format(path))
        return img

    def get_image(self, im_id):
        """Public accessor for image number ``im_id``; same result as ``self[im_id]``."""
        frame = self._get_image(im_id)
        return frame

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.image_list)

    def __getitem__(self, index):
        """Return image number ``index`` as loaded by ``cv2.imread`` (BGR)."""
        frame = self._get_image(index)
        return frame