# denoise_grayscale_pwcnet.py (106 lines, 88 loc, 5.35 KB)
# Copyright (c) 2021 Huawei Technologies Co., Ltd.
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
#
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.optim as optim
import dataset as datasets
from data import processing, sampler, DataLoader
import models.deeprep.deeprepnet as deeprep_nets
import actors.deeprep_actors as deeprep_actors
from trainers import SimpleTrainer
import data.transforms as tfm
from admin.multigpu import MultiGPU
from models.loss.image_quality_v2 import PSNR, PixelWiseError
def run(settings):
    """Configure and launch training of the Deep Reparametrization grayscale denoiser.

    Writes the training hyper-parameters onto ``settings``, builds the training and
    validation data pipelines, the network, the objective, the optimizer and the
    learning-rate schedule, and finally starts a ``SimpleTrainer``.

    Args:
        settings: Settings object. The attributes assigned below are set on it, and it
            is then passed on to ``SimpleTrainer``.
    """
    # The two literals are concatenated; the trailing space after "using" keeps the
    # words from running together ("usingpre-trained").
    settings.description = ('Default parameters for training Deep Reparametrization model for '
                            'grayscale denoising, using '
                            'pre-trained optical flow model PWC-Net')
    settings.batch_size = 8
    settings.num_workers = 8
    settings.multi_gpu = False
    settings.print_interval = 1

    settings.crop_sz = 128
    settings.burst_sz = 8
    settings.pre_downsample_factor = 4
    settings.max_jitter_small = 2
    settings.max_jitter_large = 16

    # Datasets: OpenImages for training, Zurich RAW2RGB test split for validation.
    openimages_train = datasets.OpenImagesDataset(split='train')
    zurich_val = datasets.ZurichRAW2RGB(split='test')

    # Training additionally applies a random horizontal flip.
    transform_train = tfm.Transform(tfm.ToTensorAndJitter(0.0, normalize=True),
                                    tfm.RandomHorizontalFlip())
    transform_val = tfm.Transform(tfm.ToTensorAndJitter(0.0, normalize=True))

    # Arguments shared by the training and validation processing pipelines.
    # return_grayscale is True for both: the network below is built with
    # color_input=False and is trained on grayscale output, so validation has to request
    # the same. NOTE(review): the validation pipeline previously passed
    # return_grayscale=False - confirm against DenoisingProcessing that this was unintended.
    shared_processing_args = dict(crop_sz=settings.crop_sz,
                                  burst_size=settings.burst_sz,
                                  pre_downsample_factor=settings.pre_downsample_factor,
                                  max_jitter_small=settings.max_jitter_small,
                                  max_jitter_large=settings.max_jitter_large,
                                  return_grayscale=True)
    data_processing_train = processing.DenoisingProcessing(transform=transform_train,
                                                           **shared_processing_args)
    data_processing_val = processing.DenoisingProcessing(transform=transform_val,
                                                         **shared_processing_args)

    # Train sampler and loader
    dataset_train = sampler.RandomImage([openimages_train], [1],
                                        samples_per_epoch=settings.batch_size * 1000,
                                        processing=data_processing_train)
    dataset_val = sampler.IndexedImage(zurich_val, processing=data_processing_val)

    loader_train = DataLoader('train', dataset_train, training=True,
                              num_workers=settings.num_workers,
                              stack_dim=0, batch_size=settings.batch_size)
    # epoch_interval=5: presumably the validation loader is only run every 5th epoch -
    # confirm against DataLoader / SimpleTrainer.
    loader_val = DataLoader('val', dataset_val, training=False,
                            num_workers=settings.num_workers,
                            stack_dim=0, batch_size=settings.batch_size, epoch_interval=5)

    net = deeprep_nets.deeprep_denoise_iccv21(num_iter=3, enc_dim=32, enc_num_res_blocks=4,
                                              enc_out_dim=64,
                                              dec_dim_pre=64,
                                              dec_num_res_blocks=9,
                                              dec_in_dim=16,
                                              use_noise_estimate=True,
                                              wp_project_dim=16,
                                              wp_offset_feat_dim=8,
                                              wp_num_offset_feat_extractor_res=0,
                                              wp_num_weight_predictor_res=1,
                                              color_input=False)

    # Wrap the network for multi GPU training
    if settings.multi_gpu:
        net = MultiGPU(net, dim=0)

    # Only 'rgb' has an entry in loss_weight; 'psnr' has none, so it is presumably
    # evaluated as a reported metric only - confirm against DeepRepDenoisingActor.
    objective = {
        'rgb': PixelWiseError(metric='l1', boundary_ignore=4),
        'psnr': PSNR(boundary_ignore=4),
    }
    loss_weight = {
        'rgb': 1.0,
    }

    actor = deeprep_actors.DeepRepDenoisingActor(net=net, objective=objective,
                                                 loss_weight=loss_weight)

    # The per-group 'lr' (1e-4) overrides the optimizer-wide default (2e-4), so the
    # network parameters are trained with 1e-4.
    optimizer = optim.Adam([{'params': actor.net.parameters(), 'lr': 1e-4}],
                           lr=2e-4)

    # Multiply the learning rate by 0.2 every 100 scheduler steps.
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.2)

    trainer = SimpleTrainer(actor, [loader_train, loader_val], optimizer, settings, lr_scheduler)

    # 150 is presumably the total number of epochs - confirm against SimpleTrainer.train.
    trainer.train(150, load_latest=True, fail_safe=True)