-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_PRN_r.py
118 lines (96 loc) · 4.81 KB
/
train_PRN_r.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import argparse
import paddle
from paddle.io import DataLoader
from visualdl import LogWriter
from DerainDataset import *
from utils import *
# `import pkg.module.Class as X` is invalid (Class is not a module); bind the class with from-import.
from paddle.optimizer.lr import MultiStepDecay as MultiStepLR
from SSIM import SSIM
from networks import *
def _str2bool(value):
    """Convert a command-line string such as "True", "false" or "1" to a bool.

    argparse with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, so ``--preprocess False`` used to *enable* preprocessing.
    The empty string still maps to False, as it did with ``bool('')``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean word.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


parser = argparse.ArgumentParser(description="PReNet_train")
# nargs='?' with const=True also accepts a bare "--preprocess" as True, while
# explicit "--preprocess True" / "--preprocess False" keep working.
parser.add_argument("--preprocess", type=_str2bool, nargs='?', const=True, default=False,
                    help='run prepare_data or not')
parser.add_argument("--batch_size", type=int, default=18, help="Training batch size")
parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
# These are the epoch indices handed to the LR scheduler as `milestones`.
# nargs='+' makes a command-line value a list (e.g. "--milestone 30 50 80"),
# matching the default; a lone int would be rejected by the scheduler.
parser.add_argument("--milestone", type=int, nargs='+', default=[30, 50, 80],
                    help="When to decay learning rate")
parser.add_argument("--lr", type=float, default=1e-3, help="initial learning rate")
parser.add_argument("--save_path", type=str, default="logs/PReNet_test",
                    help='path to save models and log files')
parser.add_argument("--save_freq", type=int, default=1, help='save intermediate model')
parser.add_argument("--data_path", type=str, default="datasets/train/RainTrainL",
                    help='path to training data')
parser.add_argument("--recurrent_iter", type=int, default=6, help='number of recursive stages')
# Parsed at import time; `main()` and the __main__ guard read the global `opt`.
opt = parser.parse_args()
def _train_one_epoch(model, loader_train, criterion, optimizer, writer, epoch, step):
    """Run one pass over ``loader_train`` and return the updated global step.

    Each batch is optimised with ``loss = -criterion(target, output)``, so the
    criterion's score is maximised. The updated model is then re-run in eval
    mode without gradient tracking to report PSNR on the output clipped to
    [0, 1]. Loss and PSNR are written to ``writer`` every 10 global steps.

    Args:
        model: network returning ``(output, _)`` for an input batch.
        loader_train: iterable of ``(input, target)`` batches.
        criterion: callable ``criterion(target, output)`` returning a scalar score.
        optimizer: Paddle optimizer over ``model``'s parameters.
        writer: VisualDL LogWriter used for scalar logging.
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        step: global step counter at the start of this epoch.

    Returns:
        The global step counter after this epoch.
    """
    num_batches = len(loader_train)
    for i, (input_train, target_train) in enumerate(loader_train):
        model.train()
        optimizer.clear_grad()
        out_train, _ = model(input_train)
        pixel_metric = criterion(target_train, out_train)
        loss = -pixel_metric
        loss.backward()
        optimizer.step()

        # Training curve: re-evaluate the freshly updated model. no_grad avoids
        # building (and holding) an autograd graph for a pass never back-propagated.
        model.eval()
        with paddle.no_grad():
            out_eval, _ = model(input_train)
            out_eval = paddle.clip(out_eval, 0., 1.)
        psnr_train = batch_PSNR(out_eval, target_train, 1.)
        print("[epoch %d][%d/%d] loss: %.4f, pixel_metric: %.4f, PSNR: %.4f" %
              (epoch + 1, i + 1, num_batches, loss.item(), pixel_metric.item(), psnr_train))
        if step % 10 == 0:
            # Log the scalar values
            writer.add_scalar('loss', loss.item(), step)
            writer.add_scalar('PSNR on training data', psnr_train, step)
        step += 1
    return step


def main():
    """Train ``PRN_r`` on ``opt.data_path`` and checkpoint into ``opt.save_path``.

    Model weights are resumed from the epoch number returned by
    ``findLastCheckpoint``, if it is positive. After every epoch the weights
    are saved to ``net_latest.pdparams``. They are also saved to
    ``net_epoch<N>.pdparams`` (N = 1-based epoch) whenever N is a multiple of
    ``opt.save_freq``.
    """
    print('Loading dataset ...\n')
    dataset_train = Dataset(data_path=opt.data_path)
    loader_train = DataLoader(dataset=dataset_train, num_workers=4,
                              batch_size=opt.batch_size, shuffle=True)
    print("# of training samples: %d\n" % int(len(dataset_train)))

    # Build model
    model = PRN_r(recurrent_iter=opt.recurrent_iter)
    print_network(model)

    # The criterion score is negated in the training step, so it is maximised.
    criterion = SSIM()

    # The learning rate is multiplied by gamma=0.2 at each milestone epoch.
    scheduler = MultiStepLR(learning_rate=opt.lr, milestones=opt.milestone, gamma=0.2)
    # `learning_rate` is Adam's first positional parameter, so the parameter
    # list must be passed by keyword.
    optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler)

    # Resume model weights from the newest epoch checkpoint, if any.
    # NOTE(review): optimizer state is neither saved nor restored, so Adam's
    # moment estimates restart from zero on resume.
    initial_epoch = findLastCheckpoint(save_dir=opt.save_path)
    if initial_epoch > 0:
        print('resuming by loading epoch %d' % initial_epoch)
        model.set_state_dict(paddle.load(
            os.path.join(opt.save_path, 'net_epoch%d.pdparams' % initial_epoch)))

    # TODO: per-epoch image logging (previously a commented-out torchvision
    # make_grid stub) is not implemented for VisualDL.
    step = 0
    # The context manager closes (and flushes) the VisualDL log on exit.
    with LogWriter(opt.save_path) as writer:
        for epoch in range(initial_epoch, opt.epochs):
            # Set the schedule from the explicit epoch index so a resumed run
            # starts at the correctly decayed learning rate.
            scheduler.step(epoch)
            print('learning rate %f' % optimizer.get_lr())

            step = _train_one_epoch(model, loader_train, criterion, optimizer,
                                    writer, epoch, step)

            # save model
            paddle.save(model.state_dict(), os.path.join(opt.save_path, 'net_latest.pdparams'))
            # The file name uses the 1-based epoch, so test that same number.
            if (epoch + 1) % opt.save_freq == 0:
                paddle.save(model.state_dict(),
                            os.path.join(opt.save_path, 'net_epoch%d.pdparams' % (epoch + 1)))
def _prepare_training_data(data_path):
    """Run the patch-extraction preprocessor matching ``data_path``.

    The dataset is recognised by a substring of the path, checked in the order
    RainTrainH, RainTrainL, Rain12600; the first match wins. For an
    unrecognised path only a message is printed. The caller goes on to run
    ``main()`` regardless.
    """
    if 'RainTrainH' in data_path:
        prepare_data_RainTrainH(data_path=data_path, patch_size=100, stride=80)
        return
    if 'RainTrainL' in data_path:
        prepare_data_RainTrainL(data_path=data_path, patch_size=100, stride=80)
        return
    if 'Rain12600' in data_path:
        prepare_data_Rain12600(data_path=data_path, patch_size=100, stride=100)
        return
    print('unkown datasets: please define prepare data function in DerainDataset.py')


if __name__ == "__main__":
    # Optional one-off preprocessing, then training.
    if opt.preprocess:
        _prepare_training_data(opt.data_path)
    main()