-
Notifications
You must be signed in to change notification settings - Fork 18
/
train.py
369 lines (289 loc) · 13.2 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
# Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""File description: Realize the model training function."""
import os
import shutil
import time
from enum import Enum
import torch
from torch import nn
from torch import optim
from torch.cuda import amp
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import config
from dataset import CUDAPrefetcher
from dataset import TrainValidImageDataset, TestImageDataset
from model import VDSR
def main():
    """Train VDSR end to end and checkpoint every epoch.

    Loads the datasets, builds the model, losses, optimizer and LR scheduler, and
    optionally resumes from ``config.resume``. It then trains for epochs
    ``config.start_epoch .. config.epochs - 1``. After each epoch:

    * a checkpoint ``epoch_{N}.pth.tar`` is written to ``samples/<exp_name>``;
    * if the Test PSNR beats the best seen so far, that checkpoint is copied to
      ``results/<exp_name>/best.pth.tar``;
    * after the final epoch, the checkpoint is copied to
      ``results/<exp_name>/last.pth.tar``.

    TensorBoard logs go to ``samples/logs/<exp_name>``.
    """
    # Best Test PSNR seen so far; overwritten when resuming from a checkpoint
    best_psnr = 0.0

    train_prefetcher, valid_prefetcher, test_prefetcher = load_dataset()
    print("Load train dataset and valid dataset successfully.")
    model = build_model()
    print("Build VDSR model successfully.")
    psnr_criterion, pixel_criterion = define_loss()
    print("Define all loss functions successfully.")
    optimizer = define_optimizer(model)
    print("Define all optimizer functions successfully.")
    scheduler = define_scheduler(optimizer)
    print("Define all optimizer scheduler successfully.")

    print("Check whether the pretrained model is restored...")
    if config.resume:
        # map_location returning `storage` loads every tensor onto the CPU,
        # regardless of the device it was saved from
        checkpoint = torch.load(config.resume, map_location=lambda storage, loc: storage)
        # Restore the epoch counter and best metric so training continues where it stopped
        config.start_epoch = checkpoint["epoch"]
        best_psnr = checkpoint["best_psnr"]
        # Copy only the checkpoint weights whose names exist in the current model
        model_state_dict = model.state_dict()
        new_state_dict = {k: v for k, v in checkpoint["state_dict"].items() if k in model_state_dict}
        model_state_dict.update(new_state_dict)
        model.load_state_dict(model_state_dict)
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        print("Loaded pretrained model weights.")

    samples_dir = os.path.join("samples", config.exp_name)
    results_dir = os.path.join("results", config.exp_name)
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
    os.makedirs(samples_dir, exist_ok=True)
    os.makedirs(results_dir, exist_ok=True)

    writer = SummaryWriter(os.path.join("samples", "logs", config.exp_name))
    # Gradient scaler for mixed-precision training
    scaler = amp.GradScaler()

    try:
        for epoch in range(config.start_epoch, config.epochs):
            train(model, train_prefetcher, psnr_criterion, pixel_criterion, optimizer, epoch, scaler, writer)
            _ = validate(model, valid_prefetcher, psnr_criterion, epoch, writer, "Valid")
            # NOTE(review): the "best" checkpoint is selected on Test PSNR rather than
            # Valid PSNR, which leaks test data into model selection -- confirm intended.
            psnr = validate(model, test_prefetcher, psnr_criterion, epoch, writer, "Test")
            print("\n")

            # Advance the LR schedule once per epoch
            scheduler.step()

            is_best = psnr > best_psnr
            best_psnr = max(psnr, best_psnr)
            checkpoint_path = os.path.join(samples_dir, f"epoch_{epoch + 1}.pth.tar")
            torch.save({"epoch": epoch + 1,
                        "best_psnr": best_psnr,
                        "state_dict": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "scheduler": scheduler.state_dict()},
                       checkpoint_path)
            if is_best:
                shutil.copyfile(checkpoint_path, os.path.join(results_dir, "best.pth.tar"))
            if (epoch + 1) == config.epochs:
                shutil.copyfile(checkpoint_path, os.path.join(results_dir, "last.pth.tar"))
    finally:
        # Flush pending TensorBoard events and release the event file even if training fails
        writer.close()
def load_dataset() -> "tuple[CUDAPrefetcher, CUDAPrefetcher, CUDAPrefetcher]":
    """Build the train/valid/test datasets and wrap their DataLoaders in CUDA prefetchers.

    Train and valid loaders use ``config.batch_size`` and ``config.num_workers``.
    Only the train loader shuffles and drops its last incomplete batch. The test
    loader uses batch size 1 with a single worker.

    Returns:
        ``(train_prefetcher, valid_prefetcher, test_prefetcher)``, each targeting
        ``config.device``.
    """
    train_datasets = TrainValidImageDataset(config.train_image_dir, config.image_size, "Train")
    valid_datasets = TrainValidImageDataset(config.valid_image_dir, config.image_size, "Valid")
    test_datasets = TestImageDataset(config.test_image_dir, config.upscale_factor)

    # DataLoader raises ValueError for persistent_workers=True when num_workers == 0,
    # so only keep workers alive across epochs when worker processes exist.
    keep_workers = config.num_workers > 0

    train_dataloader = DataLoader(train_datasets,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=config.num_workers,
                                  pin_memory=True,
                                  drop_last=True,
                                  persistent_workers=keep_workers)
    valid_dataloader = DataLoader(valid_datasets,
                                  batch_size=config.batch_size,
                                  shuffle=False,
                                  num_workers=config.num_workers,
                                  pin_memory=True,
                                  drop_last=False,
                                  persistent_workers=keep_workers)
    test_dataloader = DataLoader(test_datasets,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=1,
                                 pin_memory=True,
                                 drop_last=False,
                                 persistent_workers=False)

    # Wrap each loader so batches are moved to the device ahead of use
    train_prefetcher = CUDAPrefetcher(train_dataloader, config.device)
    valid_prefetcher = CUDAPrefetcher(valid_dataloader, config.device)
    test_prefetcher = CUDAPrefetcher(test_dataloader, config.device)

    return train_prefetcher, valid_prefetcher, test_prefetcher
def build_model() -> nn.Module:
    """Instantiate a VDSR network already placed on ``config.device``."""
    return VDSR().to(config.device)
def define_loss() -> "tuple[nn.MSELoss, nn.MSELoss]":
    """Create the two MSE criteria used by training and evaluation.

    Returns:
        ``(psnr_criterion, pixel_criterion)``, both on ``config.device``.

        * ``psnr_criterion`` uses the default mean reduction. It is the MSE term in
          the ``10 * log10(1 / mse)`` PSNR metric.
        * ``pixel_criterion`` uses ``reduction="sum"``. It is the optimized training
          loss.
    """
    psnr_criterion = nn.MSELoss().to(config.device)
    pixel_criterion = nn.MSELoss(reduction="sum").to(config.device)
    return psnr_criterion, pixel_criterion
def define_optimizer(model) -> optim.SGD:
    """Return an SGD optimizer over all of ``model``'s parameters.

    Learning rate, momentum, weight decay and the Nesterov flag are all read from
    ``config``.
    """
    sgd_kwargs = {
        "lr": config.model_lr,
        "momentum": config.model_momentum,
        "weight_decay": config.model_weight_decay,
        "nesterov": config.model_nesterov,
    }
    return optim.SGD(model.parameters(), **sgd_kwargs)
def define_scheduler(optimizer) -> lr_scheduler.StepLR:
    """Return a StepLR schedule for ``optimizer``.

    The LR is multiplied by ``config.lr_scheduler_gamma`` every
    ``config.lr_scheduler_step_size`` calls to ``scheduler.step()``.
    """
    step_size = config.lr_scheduler_step_size
    gamma = config.lr_scheduler_gamma
    return lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
def train(model, train_prefetcher, psnr_criterion, pixel_criterion, optimizer, epoch, scaler, writer) -> None:
    """Run one training epoch with mixed precision and LR-scaled gradient clipping.

    Args:
        model: Network to train; switched to train mode.
        train_prefetcher: Prefetcher yielding dicts with ``"lr"`` and ``"hr"`` tensors
            (presumably the low-/high-resolution image pair -- TODO confirm).
        psnr_criterion: MSE criterion used only to compute the PSNR metric.
        pixel_criterion: Loss that is back-propagated.
        optimizer: Optimizer stepped through ``scaler``.
        epoch: Zero-based epoch index, used for the progress prefix and TensorBoard step.
        scaler: ``GradScaler`` for mixed-precision backward/step.
        writer: TensorBoard SummaryWriter receiving ``"Train/Loss"``.
    """
    batches = len(train_prefetcher)
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":6.6f")
    psnres = AverageMeter("PSNR", ":4.2f")
    progress = ProgressMeter(batches, [batch_time, data_time, losses, psnres], prefix=f"Epoch: [{epoch + 1}]")

    model.train()

    batch_index = 0
    end = time.time()

    train_prefetcher.reset()
    batch_data = train_prefetcher.next()
    while batch_data is not None:
        # Time spent waiting for this batch
        data_time.update(time.time() - end)

        lr = batch_data["lr"].to(config.device, non_blocking=True)
        hr = batch_data["hr"].to(config.device, non_blocking=True)

        model.zero_grad()

        # Mixed-precision forward pass and loss
        with amp.autocast():
            sr = model(lr)
            loss = pixel_criterion(sr, hr)

        scaler.scale(loss).backward()
        # Unscale first so the clipping threshold applies to the true gradient values
        scaler.unscale_(optimizer)
        # Clip threshold scales inversely with the current LR (presumably VDSR's
        # "adjustable gradient clipping", here applied to the total L2 norm)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_gradient / optimizer.param_groups[0]["lr"], norm_type=2.0)
        scaler.step(optimizer)
        scaler.update()

        # The metric needs no autograd graph; avoid recording one for every batch
        with torch.no_grad():
            psnr = 10. * torch.log10(1. / psnr_criterion(sr, hr))
        # Single host sync for the loss value, reused for the meter and the log
        loss_value = loss.item()
        batch_size = lr.size(0)
        losses.update(loss_value, batch_size)
        psnres.update(psnr.item(), batch_size)

        batch_time.update(time.time() - end)
        end = time.time()

        if batch_index % config.print_frequency == 0:
            writer.add_scalar("Train/Loss", loss_value, batch_index + epoch * batches + 1)
            progress.display(batch_index)

        batch_data = train_prefetcher.next()
        batch_index += 1
def validate(model, valid_prefetcher, psnr_criterion, epoch, writer, mode) -> float:
    """Evaluate ``model`` over every batch of ``valid_prefetcher`` and log the mean PSNR.

    Args:
        model: Network to evaluate; switched to eval mode.
        valid_prefetcher: Prefetcher yielding dicts with ``"lr"`` and ``"hr"`` tensors.
        psnr_criterion: MSE criterion. PSNR is computed as ``10 * log10(1 / mse)``
            (assumes pixel values in [0, 1] -- TODO confirm).
        epoch: Zero-based epoch index; logged at step ``epoch + 1``.
        writer: TensorBoard SummaryWriter receiving the ``"<mode>/PSNR"`` scalar.
        mode: ``"Valid"`` or ``"Test"``. Used as the progress prefix and TensorBoard tag.

    Returns:
        Sample-weighted mean PSNR over all batches (0 if no batches were yielded).

    Raises:
        ValueError: If ``mode`` is not ``"Valid"`` or ``"Test"``. This is checked
            before any evaluation work.
    """
    # Fail fast so a bad mode does not cost a full evaluation pass before erroring
    if mode not in ("Valid", "Test"):
        raise ValueError("Unsupported mode, please use `Valid` or `Test`.")

    batch_time = AverageMeter("Time", ":6.3f", Summary.NONE)
    psnres = AverageMeter("PSNR", ":4.2f", Summary.AVERAGE)
    progress = ProgressMeter(len(valid_prefetcher), [batch_time, psnres], prefix=f"{mode}: ")

    model.eval()

    batch_index = 0
    end = time.time()

    with torch.no_grad():
        valid_prefetcher.reset()
        batch_data = valid_prefetcher.next()
        while batch_data is not None:
            lr = batch_data["lr"].to(config.device, non_blocking=True)
            hr = batch_data["hr"].to(config.device, non_blocking=True)

            # Mixed-precision inference
            with amp.autocast():
                sr = model(lr)

            psnr = 10. * torch.log10(1. / psnr_criterion(sr, hr))
            psnres.update(psnr.item(), lr.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            if batch_index % config.print_frequency == 0:
                progress.display(batch_index)

            batch_data = valid_prefetcher.next()
            batch_index += 1

    progress.display_summary()
    # Tag is "Valid/PSNR" or "Test/PSNR"; mode was validated above
    writer.add_scalar(f"{mode}/PSNR", psnres.avg, epoch + 1)
    return psnres.avg
# Copied from "https://github.com/pytorch/examples/blob/master/imagenet/main.py"
class Summary(Enum):
    """Selects which statistic ``AverageMeter.summary()`` reports."""

    NONE = 0      # report nothing (empty string)
    AVERAGE = 1   # report the running mean
    SUM = 2       # report the running sum
    COUNT = 3     # report the number of samples seen
class AverageMeter(object):
    """Track a metric's latest value plus its running weighted sum, count and mean."""

    def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt                    # format spec suffix, e.g. ":6.3f"
        self.summary_type = summary_type  # which statistic summary() reports
        self.reset()

    def reset(self):
        """Zero every accumulated statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = "".join(("{name} {val", self.fmt, "} ({avg", self.fmt, "})"))
        return template.format(**self.__dict__)

    def summary(self):
        """Return the end-of-run string selected by ``summary_type``.

        Raises:
            ValueError: If ``summary_type`` is not a ``Summary`` member.
        """
        templates = (
            (Summary.NONE, ""),
            (Summary.AVERAGE, "{name} {avg:.2f}"),
            (Summary.SUM, "{name} {sum:.2f}"),
            (Summary.COUNT, "{name} {count:.2f}"),
        )
        for kind, template in templates:
            if self.summary_type is kind:
                return template.format(**self.__dict__)
        raise ValueError(f"Invalid summary type {self.summary_type}")
class ProgressMeter(object):
    """Print per-batch progress lines and an end-of-run summary for a list of meters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print ``prefix[batch/total]`` then each meter's ``str()``, tab-separated."""
        head = self.prefix + self.batch_fmtstr.format(batch)
        print("\t".join([head, *map(str, self.meters)]))

    def display_summary(self):
        """Print ``" *"`` then each meter's ``summary()``, space-separated."""
        print(" ".join([" *", *(meter.summary() for meter in self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        """Return a ``[{:Nd}/total]`` template, N being the digit width of ``num_batches``."""
        width = len(str(num_batches // 1))
        cell = "{:" + str(width) + "d}"
        return f"[{cell}/{cell.format(num_batches)}]"
if __name__ == "__main__":
    # Start training only when executed as a script, never on import
    main()