This repository has been archived by the owner on Apr 4, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 12
/
datasets.py
110 lines (90 loc) · 3.03 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import sys
import glob
from os.path import join
import numpy as np
from PIL import Image, ImageEnhance
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torchvision.transforms as T
import config as c
from natsort import natsorted
def to_rgb(image):
    """Return a new RGB image of the same size with *image* pasted onto it."""
    canvas = Image.new("RGB", image.size)
    canvas.paste(image)
    return canvas
class Hinet_Dataset(Dataset):
    """Image dataset that lists files from a config-selected folder.

    In ``"train"`` mode the folder and file extension are chosen by
    ``c.Dataset_mode``; in ``"val"`` mode by ``c.Dataset_VAL_mode``.  Each
    sample is opened with PIL, passed through ``to_rgb`` and then through the
    supplied transform.

    Args:
        transforms_: Callable applied to each RGB image.  When ``None`` the
            RGB image itself is returned.
        mode: ``"train"`` or ``"val"``.

    Raises:
        ValueError: If ``mode`` or the configured dataset name is not one of
            the supported values.
    """

    def __init__(self, transforms_=None, mode="train"):
        super().__init__()
        self.transform = transforms_
        self.mode = mode

        if self.mode == "train":
            # TRAIN SETTING
            if c.Dataset_mode == 'DIV2K':
                self.TRAIN_PATH = c.TRAIN_PATH_DIV2K
                self.format_train = 'png'
                print('TRAIN DATASET is DIV2K')
            elif c.Dataset_mode == 'COCO':
                # NOTE(review): the training set is read from TEST_PATH_COCO.
                # Kept as-is because the config module is not visible here;
                # confirm whether a dedicated COCO training path was intended.
                self.TRAIN_PATH = c.TEST_PATH_COCO
                self.format_train = 'jpg'
                print('TRAIN DATASET is COCO')
            else:
                raise ValueError(
                    f"Unsupported c.Dataset_mode {c.Dataset_mode!r}; "
                    "expected 'DIV2K' or 'COCO'"
                )
            # train
            self.files = natsorted(
                sorted(glob.glob(self.TRAIN_PATH + "/*." + self.format_train))
            )
        elif self.mode == "val":
            # VAL SETTING
            if c.Dataset_VAL_mode == 'DIV2K':
                self.VAL_PATH = c.VAL_PATH_DIV2K
                self.format_val = 'png'
                print('VAL DATASET is DIV2K')
            elif c.Dataset_VAL_mode == 'COCO':
                self.VAL_PATH = c.VAL_PATH_COCO
                self.format_val = 'jpg'
                print('VAL DATASET is COCO')
            elif c.Dataset_VAL_mode == 'ImageNet':
                self.VAL_PATH = c.VAL_PATH_IMAGENET
                self.format_val = 'JPEG'
                print('VAL DATASET is ImageNet')
            else:
                raise ValueError(
                    f"Unsupported c.Dataset_VAL_mode {c.Dataset_VAL_mode!r}; "
                    "expected 'DIV2K', 'COCO' or 'ImageNet'"
                )
            # test
            self.files = sorted(glob.glob(self.VAL_PATH + "/*." + self.format_val))
        else:
            # Fail here rather than later with a missing ``files`` attribute.
            raise ValueError(f"Unsupported mode {mode!r}; expected 'train' or 'val'")

    def __getitem__(self, index):
        """Return the transformed sample at *index*.

        If the file at *index* cannot be opened or transformed, the following
        files are tried in order (wrapping around to the start) until one
        succeeds.

        Raises:
            IndexError: If *index* is outside ``[-len(self), len(self))``.
            RuntimeError: If no file in the dataset could be loaded; the last
                underlying error is chained as the cause.
        """
        total = len(self.files)
        if not -total <= index < total:
            raise IndexError(
                f"index {index} is out of range for dataset of size {total}"
            )

        start = index % total
        last_error = None
        # At most one pass over the file list, so a dataset made only of
        # failing files terminates instead of recursing without bound.
        for offset in range(total):
            path = self.files[(start + offset) % total]
            try:
                # The context manager closes the file handle once the pixel
                # data has been copied into the RGB image.
                with Image.open(path) as image:
                    rgb_image = to_rgb(image)
                if self.transform is None:
                    return rgb_image
                return self.transform(rgb_image)
            except Exception as err:
                # Deliberately broad: any failure to open or transform a
                # sample means "skip this sample and try the next one".
                last_error = err

        raise RuntimeError(
            f"none of the {total} files in the dataset could be loaded"
        ) from last_error

    def __len__(self):
        """Return the number of files found for the selected mode."""
        return len(self.files)
# Select the validation crop size that matches the configured validation
# dataset.  An unsupported name is rejected here so the failure names the
# offending setting instead of showing up later as an undefined variable.
if c.Dataset_VAL_mode == 'DIV2K':
    cropsize_val = c.cropsize_val_div2k
elif c.Dataset_VAL_mode == 'COCO':
    cropsize_val = c.cropsize_val_coco
elif c.Dataset_VAL_mode == 'ImageNet':
    cropsize_val = c.cropsize_val_imagenet
else:
    raise ValueError(
        f"Unsupported c.Dataset_VAL_mode {c.Dataset_VAL_mode!r}; "
        "expected 'DIV2K', 'COCO' or 'ImageNet'"
    )
# Training pipeline: random horizontal and vertical flips, then a random crop
# of size c.cropsize, then conversion to a tensor.
transform = T.Compose(
    [
        T.RandomHorizontalFlip(),
        T.RandomVerticalFlip(),
        T.RandomCrop(c.cropsize),
        T.ToTensor(),
    ]
)

# Validation pipeline: centre crop of size cropsize_val, then conversion to a
# tensor.
transform_val = T.Compose(
    [
        T.CenterCrop(cropsize_val),
        T.ToTensor(),
    ]
)
# Loader over the training split: shuffled, incomplete final batch dropped.
trainloader = DataLoader(
    dataset=Hinet_Dataset(mode="train", transforms_=transform),
    batch_size=c.batch_size,
    shuffle=True,
    num_workers=1,
    pin_memory=False,
    drop_last=True,
)

# Loader over the validation split: fixed order, incomplete final batch
# dropped.
testloader = DataLoader(
    dataset=Hinet_Dataset(mode="val", transforms_=transform_val),
    batch_size=c.batchsize_val,
    shuffle=False,
    num_workers=1,
    pin_memory=False,
    drop_last=True,
)