forked from jeya-maria-jose/TransWeather
-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_data_functions_seq_batch.py
executable file
·119 lines (102 loc) · 4.48 KB
/
train_data_functions_seq_batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from pkg_resources import invalid_marker
import torch.utils.data as data
from PIL import Image
from random import randrange, shuffle
from torchvision.transforms import Compose, ToTensor, Normalize
import re
from PIL import ImageFile
from os import path
import numpy as np
import torch
import glob, os, random
import torchvision.transforms.functional as TF
from tqdm import tqdm
from glob import glob
# Process-wide Pillow setting: load truncated/partially written image files
# instead of raising an error when their pixel data is read.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# --- Training dataset --- #
class TrainData(data.Dataset):
    """Sequence-ordered training dataset pairing input images with one ground truth.

    For every name ``seq`` in ``sequences``, the ground-truth frames are the
    ``*.jpg`` files in ``<train_data_dir>/<seq>_<gt_dir[0]>``. For each frame,
    one input path per entry ``rain_sub`` of ``rain_dir`` is formed by taking the
    same file name from ``<train_data_dir>/<seq>_<rain_sub>``.

    Items are ``(input_ims, gts, is_continue)``; ``is_continue`` is False for the
    first frame of every sequence. One augmentation (0 = none, 1 = hflip,
    2 = vflip, 3/4/5 = rotate 90/180/270 degrees) is drawn at random at each
    sequence start and reused for the following frames.

    NOTE(review): the augmentation choice is stored in ``self.aug``, so frames of
    one sequence share it only when they are requested in index order from the
    same dataset object (e.g. ``shuffle=False`` and a single loader worker) —
    confirm the training loop does this.
    """

    def __init__(self, crop_size, train_data_dir, gt_dir, rain_dir, sequences):
        super().__init__()
        assert isinstance(gt_dir, list)
        assert isinstance(rain_dir, list)
        assert isinstance(sequences, list)
        assert len(gt_dir) == 1
        gt_dir = gt_dir[0]
        print("Reading Data...")
        self.gt_names, self.input_names, self.seq_len = self.getAllImageNames(
            train_data_dir, gt_dir, rain_dir, sequences)
        print("====> [INFO] Total number of training data: ", len(self.gt_names))
        print("====> [INFO] Training sequence points: ", self.seq_len)
        # NOTE(review): crop_size is stored but not used anywhere in this class.
        self.crop_size = crop_size
        self.train_data_dir = train_data_dir
        # Index of the sequence containing the most recently loaded frame.
        self.seq_index = 0
        # Current augmentation id (0-5); re-drawn at every sequence start.
        self.aug = random.randint(0, 5)

    def getAllImageNames(self, train_data_dir, gt_dir, rain_dir, sequences):
        """Collect ground-truth / input path pairs for all sequences.

        Args:
            train_data_dir: root folder containing ``<seq>_<suffix>`` sub-folders.
            gt_dir: folder suffix (a single string) of the ground-truth frames.
            rain_dir: list of folder suffixes; each contributes one input path
                per ground-truth frame.
            sequences: list of sequence-name prefixes.

        Returns:
            ``(gt_imgs, rain_imgs, seq_len)``: ``gt_imgs[i]`` is a ground-truth
            path, ``rain_imgs[i]`` the list of matching input paths (one per
            ``rain_dir`` entry), and ``seq_len[k]`` the cumulative number of
            frames in sequences ``0..k`` (the exclusive end index of sequence k).
        """
        gt_imgs = []
        rain_imgs = []
        seq_len = []
        for seq in sequences:
            train_gt_dir = os.path.join(train_data_dir, seq + "_" + gt_dir)
            rain_sub_dirs = [os.path.join(train_data_dir, seq + "_" + rain_sub)
                             for rain_sub in rain_dir]
            # glob() yields files in arbitrary, filesystem-dependent order; sort so
            # the frame order inside a sequence is deterministic.
            # NOTE(review): this is lexicographic order — frame numbers must be
            # zero-padded for it to match numeric order.
            gt_img_names = sorted(glob(os.path.join(train_gt_dir, "*.jpg")))
            for gt_img_n in gt_img_names:
                frame_name = os.path.basename(gt_img_n)
                rain_imgs.append([os.path.join(sub_dir, frame_name)
                                  for sub_dir in rain_sub_dirs])
            gt_imgs.extend(gt_img_names)
            seq_len.append(len(gt_imgs))
        assert len(gt_imgs) == len(rain_imgs)
        # seq_len is empty when no sequences were given.
        assert not seq_len or len(gt_imgs) == seq_len[-1]
        return gt_imgs, rain_imgs, seq_len

    def _sequence_position(self, index):
        """Return ``(seq_index, is_continue)`` for the frame at ``index``.

        ``self.seq_len`` holds cumulative exclusive end offsets, so the sequence
        index is the number of offsets ``<= index`` (empty sequences, i.e.
        repeated offsets, are skipped). ``is_continue`` is False exactly when
        ``index`` is the first frame of its sequence. Computed from ``index``
        alone, so it is correct across epochs.
        """
        from bisect import bisect_right
        seq_index = bisect_right(self.seq_len, index)
        seq_start = self.seq_len[seq_index - 1] if seq_index > 0 else 0
        return seq_index, index != seq_start

    @staticmethod
    def _load_rgb_tensor(path):
        """Read ``path`` as an RGB image and return it via ``TF.to_tensor``."""
        # convert('RGB') gives every image 3 channels, matching the 3-channel
        # normalization below; the with-block closes the file handle.
        with Image.open(path) as img:
            return TF.to_tensor(img.convert('RGB'))

    def get_images(self, index):
        """Load, augment and normalize the sample at ``index``.

        Returns:
            ``(input_ims, gts, is_continue)``: ``input_ims`` holds one normalized
            tensor per input folder, ``gts`` holds one copy of the ground-truth
            tensor per input image, and ``is_continue`` is False when ``index``
            is the first frame of its sequence.
        """
        seq_index, is_continue = self._sequence_position(index)
        self.seq_index = seq_index
        if not is_continue:
            # Draw the sequence's augmentation *before* applying it, so the
            # first frame is transformed like the rest of its sequence.
            self.aug = random.randint(0, 5)

        input_ims = [self._load_rgb_tensor(name) for name in self.input_names[index]]
        gt = self._load_rgb_tensor(self.gt_names[index])

        ########### --- NOTE: data augmentation --- ###############
        # NOTE(review): TF.rotate keeps the canvas size by default, so 90/270
        # degree rotations of non-square frames crop and zero-fill corners.
        augment = {
            1: TF.hflip,
            2: TF.vflip,
            3: lambda t: TF.rotate(t, 90),
            4: lambda t: TF.rotate(t, 180),
            5: lambda t: TF.rotate(t, 270),
        }.get(self.aug)  # 0 means no augmentation
        if augment is not None:
            input_ims = [augment(im) for im in input_ims]
            gt = augment(gt)
        ###########################################################

        # Every input image is paired with the same ground truth; clone so each
        # list entry is an independent tensor.
        gts = [gt.clone() for _ in input_ims]

        # --- Normalize the input images: (x - 0.5) / 0.5 per channel --- #
        input_ims = [TF.normalize(im, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                     for im in input_ims]
        return input_ims, gts, is_continue

    def __getitem__(self, index):
        """Return ``get_images(index)``."""
        return self.get_images(index)

    def __len__(self):
        """Number of ground-truth frames across all sequences."""
        return len(self.input_names)