# train.py
# SPDX-License-Identifier: MIT
# See COPYING file for more details.
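"""Training loop: optimizes a multi-scale L1 pixel loss plus an L1 loss on
the 2D FFT of each prediction (weighted by 0.1), logs progress to TensorBoard
and to text files under ./log, and writes checkpoints to args.model_save_dir.
"""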
import os

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from data import train_dataloader
from utils import Adder, Timer
from valid import _valid


def _train(model, args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    criterion = torch.nn.L1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-8)
    dataloader = train_dataloader(args.data_dir, args.batch_size, args.num_worker)
    max_iter = len(dataloader)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epoch, eta_min=1e-6)
    scheduler.step()

    # Create the log and checkpoint directories up front so the writes below
    # cannot fail on a fresh run.
    os.makedirs('./log', exist_ok=True)
    os.makedirs(args.model_save_dir, exist_ok=True)
    epoch = 1
    if args.resume:
        state = torch.load(args.resume)
        epoch = state['epoch']
        optimizer.load_state_dict(state['optimizer'])
        model.load_state_dict(state['model'])
        print('Resume from %d' % epoch)
        epoch += 1
        # The checkpoint stores no scheduler state, so fast-forward the cosine
        # schedule to match the resumed epoch.
        for _ in range(epoch - 1):
            scheduler.step()

    writer = SummaryWriter('runs')
    epoch_pixel_adder = Adder()
    epoch_fft_adder = Adder()
    iter_pixel_adder = Adder()
    iter_fft_adder = Adder()
    epoch_timer = Timer('m')
    iter_timer = Timer('m')
    best_psnr = -1
    for epoch_idx in range(epoch, args.num_epoch + 1):
        epoch_timer.tic()
        iter_timer.tic()
        for iter_idx, batch_data in enumerate(tqdm(dataloader)):
            input_img, label_img = batch_data
            input_img = input_img.to(device)
            label_img = label_img.to(device)

            optimizer.zero_grad()
            # The model returns predictions at 1/4, 1/2 and full resolution;
            # build matching ground-truth pyramids for multi-scale supervision.
            pred_img = model(input_img)
            label_img2 = F.interpolate(label_img, scale_factor=0.5, mode='bilinear')
            label_img4 = F.interpolate(label_img, scale_factor=0.25, mode='bilinear')

            # Multi-scale L1 loss in the pixel domain.
            l1 = criterion(pred_img[0], label_img4)
            l2 = criterion(pred_img[1], label_img2)
            l3 = criterion(pred_img[2], label_img)
            loss_content = l1 + l2 + l3

            # Frequency reconstruction loss: L1 between the real/imaginary
            # parts of the 2D FFT of each prediction and its label.
            label_fft1 = torch.fft.fft2(label_img4, dim=(-2, -1))
            label_fft1 = torch.stack((label_fft1.real, label_fft1.imag), -1)
            pred_fft1 = torch.fft.fft2(pred_img[0], dim=(-2, -1))
            pred_fft1 = torch.stack((pred_fft1.real, pred_fft1.imag), -1)
            label_fft2 = torch.fft.fft2(label_img2, dim=(-2, -1))
            label_fft2 = torch.stack((label_fft2.real, label_fft2.imag), -1)
            pred_fft2 = torch.fft.fft2(pred_img[1], dim=(-2, -1))
            pred_fft2 = torch.stack((pred_fft2.real, pred_fft2.imag), -1)
            label_fft3 = torch.fft.fft2(label_img, dim=(-2, -1))
            label_fft3 = torch.stack((label_fft3.real, label_fft3.imag), -1)
            pred_fft3 = torch.fft.fft2(pred_img[2], dim=(-2, -1))
            pred_fft3 = torch.stack((pred_fft3.real, pred_fft3.imag), -1)
            f1 = criterion(pred_fft1, label_fft1)
            f2 = criterion(pred_fft2, label_fft2)
            f3 = criterion(pred_fft3, label_fft3)
            loss_fft = f1 + f2 + f3

            loss = loss_content + 0.1 * loss_fft
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.001)
            optimizer.step()

            iter_pixel_adder(loss_content.item())
            iter_fft_adder(loss_fft.item())
            epoch_pixel_adder(loss_content.item())
            epoch_fft_adder(loss_fft.item())
            if (iter_idx + 1) % args.print_freq == 0:
                output_string = "Time: %7.4f Epoch: %03d Iter: %4d/%4d LR: %.10f Loss content: %7.4f Loss fft: %7.4f" % (
                    iter_timer.toc(), epoch_idx, iter_idx + 1, max_iter, scheduler.get_last_lr()[0],
                    iter_pixel_adder.average(), iter_fft_adder.average())
                print(output_string)
                with open('./log/train_log.txt', 'a') as f:
                    f.write(output_string + '\n')
                writer.add_scalar('Pixel Loss', iter_pixel_adder.average(), iter_idx + (epoch_idx - 1) * max_iter)
                writer.add_scalar('FFT Loss', iter_fft_adder.average(), iter_idx + (epoch_idx - 1) * max_iter)
                iter_timer.tic()
                iter_pixel_adder.reset()
                iter_fft_adder.reset()
        overwrite_name = os.path.join(args.model_save_dir, 'model.pkl')
        torch.save({'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch_idx}, overwrite_name)
        if epoch_idx % args.save_freq == 0:
            save_name = os.path.join(args.model_save_dir, 'model_%d.pkl' % epoch_idx)
            torch.save({'model': model.state_dict()}, save_name)

        print("EPOCH: %02d\nElapsed time: %4.2f Epoch Pixel Loss: %7.4f Epoch FFT Loss: %7.4f" % (
            epoch_idx, epoch_timer.toc(), epoch_pixel_adder.average(), epoch_fft_adder.average()))
        epoch_fft_adder.reset()
        epoch_pixel_adder.reset()
        scheduler.step()
        if epoch_idx % args.valid_freq == 0:
            val = _valid(model, args, epoch_idx)
            print('%03d epoch \n Average PSNR %.2f dB' % (epoch_idx, val))
            with open('./log/train_psnr.txt', 'a') as f:
                f.write('Epoch:{} PSNR:{}\n'.format(epoch_idx, val))
            writer.add_scalar('PSNR', val, epoch_idx)
            if val >= best_psnr:
                torch.save({'model': model.state_dict()}, os.path.join(args.model_save_dir, 'Best.pkl'))
                best_psnr = val

    save_name = os.path.join(args.model_save_dir, 'Final.pkl')
    torch.save({'model': model.state_dict()}, save_name)
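

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch, not part of the original script. The real
# project presumably builds the model and parses arguments in its own entry
# point; the argument names below mirror the attributes read by _train, but
# the default values and the _ToyMultiScaleNet stand-in are assumptions for
# illustration only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse

    class _ToyMultiScaleNet(torch.nn.Module):
        """Hypothetical stand-in with the interface _train expects: a list
        of three predictions at 1/4, 1/2 and full resolution."""

        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3, padding=1)

        def forward(self, x):
            out = self.conv(x)
            return [F.interpolate(out, scale_factor=0.25, mode='bilinear'),
                    F.interpolate(out, scale_factor=0.5, mode='bilinear'),
                    out]

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='dataset/GOPRO')
    parser.add_argument('--model_save_dir', type=str, default='results/weights')
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--num_worker', type=int, default=8)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--num_epoch', type=int, default=3000)
    parser.add_argument('--print_freq', type=int, default=100)
    parser.add_argument('--save_freq', type=int, default=100)
    parser.add_argument('--valid_freq', type=int, default=100)
    parser.add_argument('--resume', type=str, default='')
    args = parser.parse_args()

    _train(_ToyMultiScaleNet(), args)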