Commit: add LDL loss (#523)

* add weighting func

* add LDL loss

* add weighting func

* add LDL loss

* add LDL loss

* fit formats

* fix format

* adapt variables

* adapt variables

* minor

Co-authored-by: Xintao <wxt1994@126.com>
csjliang and xinntao authored May 9, 2022
1 parent 479ec97 commit 5425075
Showing 4 changed files with 245 additions and 0 deletions.
50 changes: 50 additions & 0 deletions basicsr/losses/loss_util.py
@@ -1,4 +1,5 @@
import functools
import torch
from torch.nn import functional as F


@@ -93,3 +94,52 @@ def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
        return loss

    return wrapper


def get_local_weights(residual, ksize):
    """Get local weights for generating the artifact map of LDL.

    It is only called by the `get_refined_artifact_map` function.

    Args:
        residual (Tensor): Residual between predicted and ground truth images.
        ksize (Int): size of the local window.

    Returns:
        Tensor: weight for each pixel to be discriminated as an artifact pixel
    """

    pad = (ksize - 1) // 2
    residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect')

    unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1)
    pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1)

    return pixel_level_weight


def get_refined_artifact_map(img_gt, img_output, img_ema, ksize):
    """Calculate the artifact map of LDL
    (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022)

    Args:
        img_gt (Tensor): ground truth images.
        img_output (Tensor): output images given by the optimizing model.
        img_ema (Tensor): output images given by the ema model.
        ksize (Int): size of the local window.

    Returns:
        overall_weight: weight for each pixel to be discriminated as an artifact pixel
            (calculated based on both local and global observations).
    """

    residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True)
    residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True)

    patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5)
    pixel_level_weight = get_local_weights(residual_sr.clone(), ksize)
    overall_weight = patch_level_weight * pixel_level_weight

    overall_weight[residual_sr < residual_ema] = 0

    return overall_weight
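
For orientation, a minimal standalone sketch (not part of the commit) of how these two helpers are meant to be used: build the artifact map from the ground truth, the current output and the EMA output, then apply it as a per-pixel weight on an L1 term, mirroring the call in realesrgan_model.py below. The random tensors and shapes are illustrative only.

# Usage sketch with random tensors standing in for real image batches.
import torch
import torch.nn.functional as F

from basicsr.losses.loss_util import get_refined_artifact_map

img_gt = torch.rand(1, 3, 64, 64)      # ground-truth batch
img_output = torch.rand(1, 3, 64, 64)  # output of the model being optimized
img_ema = torch.rand(1, 3, 64, 64)     # output of the EMA model

# ksize=7 matches the window size used in realesrgan_model.py below
overall_weight = get_refined_artifact_map(img_gt, img_output, img_ema, 7)
print(overall_weight.shape)  # torch.Size([1, 1, 64, 64]); one weight per pixel

# LDL term: weighted L1 between output and ground truth
l_g_ldl = F.l1_loss(overall_weight * img_output, overall_weight * img_gt)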
8 changes: 8 additions & 0 deletions basicsr/models/realesrgan_model.py
@@ -6,6 +6,7 @@

from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from basicsr.data.transforms import paired_random_crop
from basicsr.losses.loss_util import get_refined_artifact_map
from basicsr.models.srgan_model import SRGANModel
from basicsr.utils import DiffJPEG, USMSharp
from basicsr.utils.img_process_util import filter2D
@@ -207,6 +208,8 @@ def optimize_parameters(self, current_iter):

        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)
        if self.cri_ldl:
            self.output_ema = self.net_g_ema(self.lq)

        l_g_total = 0
        loss_dict = OrderedDict()
@@ -216,6 +219,11 @@
                l_g_pix = self.cri_pix(self.output, l1_gt)
                l_g_total += l_g_pix
                loss_dict['l_g_pix'] = l_g_pix
            if self.cri_ldl:
                pixel_weight = get_refined_artifact_map(self.gt, self.output, self.output_ema, 7)
                l_g_ldl = self.cri_ldl(torch.mul(pixel_weight, self.output), torch.mul(pixel_weight, self.gt))
                l_g_total += l_g_ldl
                loss_dict['l_g_ldl'] = l_g_ldl
            # perceptual loss
            if self.cri_perceptual:
                l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
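
Note on the LDL term just above: because the weight map is non-negative, feeding pixel_weight * output and pixel_weight * gt into an L1 loss is the same as weighting the absolute residual pixel by pixel. A quick standalone check of that identity (illustrative shapes; not part of the commit):

# Check that weighting both L1 inputs equals a per-pixel weighted L1,
# given a non-negative weight map of shape (N, 1, H, W) that broadcasts
# over the RGB channels, as in optimize_parameters above.
import torch
import torch.nn.functional as F

output = torch.rand(2, 3, 64, 64)
gt = torch.rand(2, 3, 64, 64)
pixel_weight = torch.rand(2, 1, 64, 64)  # stand-in for get_refined_artifact_map(...)

lhs = F.l1_loss(pixel_weight * output, pixel_weight * gt)  # as written in the diff
rhs = (pixel_weight * (output - gt).abs()).mean()          # explicit per-pixel weighting
assert torch.allclose(lhs, rhs, atol=1e-6)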
5 changes: 5 additions & 0 deletions basicsr/models/srgan_model.py
@@ -51,6 +51,11 @@ def init_training_settings(self):
        else:
            self.cri_pix = None

        if train_opt.get('ldl_opt'):
            self.cri_ldl = build_loss(train_opt['ldl_opt']).to(self.device)
        else:
            self.cri_ldl = None

        if train_opt.get('perceptual_opt'):
            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
        else:
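
The new train_opt.get('ldl_opt') branch above is fed by the ldl_opt block that the option file below adds under train:. A rough standalone sketch of that wiring, with a dict literal standing in for the parsed YAML (illustrative only; in BasicSR the dict comes from the option file):

# Sketch: how an `ldl_opt` entry in the training options becomes self.cri_ldl.
import torch
from basicsr.losses import build_loss

train_opt = {
    'ldl_opt': {'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'},
}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cri_ldl = build_loss(train_opt['ldl_opt']).to(device) if train_opt.get('ldl_opt') else None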
182 changes: 182 additions & 0 deletions options/train/LDL/train_LDL_Real_x4.yml
@@ -0,0 +1,182 @@
# general settings
name: train_LDL_realworld_RRDB
model_type: RealESRGANModel
scale: 4
num_gpu: 4
manual_seed: 0

# ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
# USM the ground-truth
l1_gt_usm: True
percep_gt_usm: True
gan_gt_usm: False

# the first degradation process
resize_prob: [0.2, 0.7, 0.1]  # up, down, keep
resize_range: [0.15, 1.5]
gaussian_noise_prob: 0.5
noise_range: [1, 30]
poisson_scale_range: [0.05, 3]
gray_noise_prob: 0.4
jpeg_range: [30, 95]

# the second degradation process
second_blur_prob: 0.8
resize_prob2: [0.3, 0.4, 0.3]  # up, down, keep
resize_range2: [0.3, 1.2]
gaussian_noise_prob2: 0.5
noise_range2: [1, 25]
poisson_scale_range2: [0.05, 2.5]
gray_noise_prob2: 0.4
jpeg_range2: [30, 95]

gt_size: 256
queue_size: 180

# dataset and data loader settings
datasets:
  train:
    name: DF2K+OST
    type: RealESRGANDataset
    dataroot_gt: datasets/DF2K
    meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
    io_backend:
      type: disk

    blur_kernel_size: 21
    kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
    kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
    sinc_prob: 0.1
    blur_sigma: [0.2, 3]
    betag_range: [0.5, 4]
    betap_range: [1, 2]

    blur_kernel_size2: 21
    kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
    kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
    sinc_prob2: 0.1
    blur_sigma2: [0.2, 1.5]
    betag_range2: [0.5, 4]
    betap_range2: [1, 2]

    final_sinc_prob: 0.8

    gt_size: 256
    use_hflip: True
    use_rot: False

    # data loader
    use_shuffle: true
    num_worker_per_gpu: 4
    batch_size_per_gpu: 4
    dataset_enlarge_ratio: 1
    prefetch_mode: ~

  val:
    name: RealWorld38
    type: SingleImageDataset
    dataroot_lq: datasets/RealWorld38/LR
    io_backend:
      type: disk

# network structures
network_g:
  type: RRDBNet
  num_in_ch: 3
  num_out_ch: 3
  num_feat: 64
  num_block: 23
  num_grow_ch: 32

network_d:
  type: UNetDiscriminatorSN
  num_in_ch: 3
  num_feat: 64
  skip_connection: True

# path
path:
  # use the pre-trained Real-ESRNet model
  pretrain_network_g: experiments/pretrained_models/RealESRGAN/RealESRNet_x4plus.pth
  param_key_g: params_ema
  strict_load_g: true
  resume_state: ~

# training settings
train:
  ema_decay: 0.999
  optim_g:
    type: Adam
    lr: !!float 1e-4
    weight_decay: 0
    betas: [0.9, 0.99]
  optim_d:
    type: Adam
    lr: !!float 1e-4
    weight_decay: 0
    betas: [0.9, 0.99]

  scheduler:
    type: MultiStepLR
    milestones: [400000]
    gamma: 0.5

  total_iter: 400000
  warmup_iter: -1  # no warm up

  # losses
  pixel_opt:
    type: L1Loss
    loss_weight: !!float 1e-2
    reduction: mean
  ldl_opt:
    type: L1Loss
    loss_weight: !!float 1.0
    reduction: mean
  # perceptual loss (content and style losses)
  perceptual_opt:
    type: PerceptualLoss
    layer_weights:
      # before relu
      'conv1_2': 0.1
      'conv2_2': 0.1
      'conv3_4': 1
      'conv4_4': 1
      'conv5_4': 1
    vgg_type: vgg19
    use_input_norm: true
    perceptual_weight: !!float 1.0
    style_weight: 0
    range_norm: false
    criterion: l1
  # gan loss
  gan_opt:
    type: GANLoss
    gan_type: vanilla
    real_label_val: 1.0
    fake_label_val: 0.0
    loss_weight: !!float 1e-1

  net_d_iters: 1
  net_d_init_iters: 0

# Uncomment these for validation
# validation settings
val:
  val_freq: !!float 5e3
  save_img: True

# logging settings
logger:
  print_freq: 100
  save_checkpoint_freq: !!float 5e3
  use_tb_logger: true
  wandb:
    project: ~
    resume_id: ~

# dist training settings
dist_params:
  backend: nccl
  port: 29500
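
As a quick, illustrative sanity check of where the new block lands (not part of the commit; plain PyYAML is used here instead of BasicSR's own ordered-YAML loader), the ldl_opt entry should resolve under train:

# Parse the option file added by this commit and inspect the new LDL entry.
# yaml.safe_load understands the standard !!float tags used above.
import yaml

with open('options/train/LDL/train_LDL_Real_x4.yml') as f:
    opt = yaml.safe_load(f)

print(opt['train']['ldl_opt'])    # {'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'}
print(opt['train']['pixel_opt'])  # the plain L1 term that the LDL term complements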
