train.py
import argparse
import os

import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
from torchvision.models.vgg import vgg16  # only used by the optional perceptual loss (commented out below)
from tqdm import tqdm

from dataset import Dataset
from edar import EDAR
from utils import AverageMeter


if __name__ == '__main__':
    '''
    Enable benchmark mode in cudnn.
    Benchmark mode is good whenever the input sizes for the network do not vary:
    cudnn looks for the optimal set of algorithms for that particular
    configuration (which takes some time), usually leading to faster runtime.
    But if the input sizes change at each iteration, cudnn will benchmark every
    time a new size appears, possibly leading to worse runtime performance.
    '''
    cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    parser = argparse.ArgumentParser()
    parser.add_argument('--images_dir', type=str, required=True)
    parser.add_argument('--outputs_dir', type=str, required=True)
    parser.add_argument('--jpeg_quality', type=int, default=40)
    parser.add_argument('--patch_size', type=int, default=48)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_epochs', type=int, default=400)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--threads', type=int, default=1)
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    opt = parser.parse_args()
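
    # Example invocation (hypothetical paths; adjust to your data):
    #   python train.py --images_dir ./data/train --outputs_dir ./checkpoints \
    #                   --jpeg_quality 40 --batch_size 16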
    if not os.path.exists(opt.outputs_dir):
        os.makedirs(opt.outputs_dir)

    torch.manual_seed(opt.seed)

    transforms_train = transforms.Compose([transforms.ToTensor()])
    model = EDAR().to(device)
    print("Model loaded")

    if opt.resume:
        if os.path.isfile(opt.resume):
            # state_dict() returns references to the model's live parameter
            # tensors, so copy_() below updates the model in place.
            state_dict = model.state_dict()
            for n, p in torch.load(opt.resume, map_location=lambda storage, loc: storage).items():
                if n in state_dict.keys():
                    state_dict[n].copy_(p)
                else:
                    raise KeyError(n)
        else:
            raise FileNotFoundError('no checkpoint found at {}'.format(opt.resume))
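
    # An equivalent, more common loading pattern (a sketch; assumes the file
    # stores a plain state dict rather than a wrapped checkpoint) would be:
    #   model.load_state_dict(torch.load(opt.resume, map_location='cpu'))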
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)

    print("Data processing started")
    dataset = Dataset(opt.images_dir, opt.patch_size, opt.jpeg_quality, transforms=transforms_train)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.threads,
                            pin_memory=True,
                            drop_last=True)
    print("Data loading completed")
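
    # Each dataset item is assumed to be an (input, label) pair of tensors:
    # the JPEG-compressed patch and its uncompressed original, both shaped
    # (3, patch_size, patch_size) after ToTensor(); see dataset.py.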
    # vgg = vgg16(pretrained=True).cuda()
    # loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
    # for param in loss_network.parameters():
    #     param.requires_grad = False
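
    # If the perceptual loss above is re-enabled, the total loss becomes
    # L1(outs, labels) + 0.06 * L1(vgg_features(outs), vgg_features(labels)),
    # matching the commented perception_loss lines in the loop below.
    # Note: pretrained=True downloads ImageNet weights on first use.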
    for epoch in range(opt.num_epochs):
        epoch_losses = AverageMeter()
        print("Length of the dataset is", len(dataset))

        with tqdm(total=(len(dataset) - len(dataset) % opt.batch_size)) as _tqdm:
            _tqdm.set_description('epoch: {}/{}'.format(epoch + 1, opt.num_epochs))
            for data in dataloader:
                inputs, labels = data
                inputs = inputs.to(device)
                labels = labels.to(device)
                # print(inputs.size(), labels.size())

                outs = model(inputs)
                loss = criterion(outs, labels)
                # perception_loss = criterion(loss_network(outs), loss_network(labels))
                # loss = loss + perception_loss * 0.06

                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                _tqdm.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                _tqdm.update(len(inputs))

        torch.save(model.state_dict(), os.path.join(opt.outputs_dir, 'EDAR_epoch_{}.pth'.format(epoch)))
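
        # Note: only model weights are saved above. To make --resume restore
        # training exactly, one could also checkpoint the optimizer state and
        # epoch index, e.g.:
        #   torch.save({'epoch': epoch,
        #               'model': model.state_dict(),
        #               'optimizer': optimizer.state_dict()},
        #              os.path.join(opt.outputs_dir, 'checkpoint.pth'))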