From 542507534c3e959f7598e6f74687738a930b9b7a Mon Sep 17 00:00:00 2001 From: csjliang <67353604+csjliang@users.noreply.github.com> Date: Mon, 9 May 2022 15:32:07 +0800 Subject: [PATCH] add LDL loss (#523) * add weighting func * add LDL loss * add weighting func * add LDL loss * add LDL loss * fit formats * fix format * adapt variables * adapt variables * minor Co-authored-by: Xintao --- basicsr/losses/loss_util.py | 50 +++++++ basicsr/models/realesrgan_model.py | 8 ++ basicsr/models/srgan_model.py | 5 + options/train/LDL/train_LDL_Real_x4.yml | 182 ++++++++++++++++++++++++ 4 files changed, 245 insertions(+) create mode 100644 options/train/LDL/train_LDL_Real_x4.yml diff --git a/basicsr/losses/loss_util.py b/basicsr/losses/loss_util.py index 744eeb46d..fd293ff9e 100644 --- a/basicsr/losses/loss_util.py +++ b/basicsr/losses/loss_util.py @@ -1,4 +1,5 @@ import functools +import torch from torch.nn import functional as F @@ -93,3 +94,52 @@ def wrapper(pred, target, weight=None, reduction='mean', **kwargs): return loss return wrapper + + +def get_local_weights(residual, ksize): + """Get local weights for generating the artifact map of LDL. + + It is only called by the `get_refined_artifact_map` function. + + Args: + residual (Tensor): Residual between predicted and ground truth images. + ksize (Int): size of the local window. + + Returns: + Tensor: weight for each pixel to be discriminated as an artifact pixel + """ + + pad = (ksize - 1) // 2 + residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect') + + unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1) + pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1) + + return pixel_level_weight + + +def get_refined_artifact_map(img_gt, img_output, img_ema, ksize): + """Calculate the artifact map of LDL + (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022) + + Args: + img_gt (Tensor): ground truth images. + img_output (Tensor): output images given by the optimizing model. + img_ema (Tensor): output images given by the ema model. + ksize (Int): size of the local window. + + Returns: + overall_weight: weight for each pixel to be discriminated as an artifact pixel + (calculated based on both local and global observations). + """ + + residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True) + residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True) + + patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5) + pixel_level_weight = get_local_weights(residual_sr.clone(), ksize) + overall_weight = patch_level_weight * pixel_level_weight + + overall_weight[residual_sr < residual_ema] = 0 + + return overall_weight diff --git a/basicsr/models/realesrgan_model.py b/basicsr/models/realesrgan_model.py index eb0ec1ca7..c74b28fb1 100644 --- a/basicsr/models/realesrgan_model.py +++ b/basicsr/models/realesrgan_model.py @@ -6,6 +6,7 @@ from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt from basicsr.data.transforms import paired_random_crop +from basicsr.losses.loss_util import get_refined_artifact_map from basicsr.models.srgan_model import SRGANModel from basicsr.utils import DiffJPEG, USMSharp from basicsr.utils.img_process_util import filter2D @@ -207,6 +208,8 @@ def optimize_parameters(self, current_iter): self.optimizer_g.zero_grad() self.output = self.net_g(self.lq) + if self.cri_ldl: + self.output_ema = self.net_g_ema(self.lq) l_g_total = 0 loss_dict = OrderedDict() @@ -216,6 +219,11 @@ def optimize_parameters(self, current_iter): l_g_pix = self.cri_pix(self.output, l1_gt) l_g_total += l_g_pix loss_dict['l_g_pix'] = l_g_pix + if self.cri_ldl: + pixel_weight = get_refined_artifact_map(self.gt, self.output, self.output_ema, 7) + l_g_ldl = self.cri_ldl(torch.mul(pixel_weight, self.output), torch.mul(pixel_weight, self.gt)) + l_g_total += l_g_ldl + loss_dict['l_g_ldl'] = l_g_ldl # perceptual loss if self.cri_perceptual: l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt) diff --git a/basicsr/models/srgan_model.py b/basicsr/models/srgan_model.py index 0c187dc2b..45387ca79 100644 --- a/basicsr/models/srgan_model.py +++ b/basicsr/models/srgan_model.py @@ -51,6 +51,11 @@ def init_training_settings(self): else: self.cri_pix = None + if train_opt.get('ldl_opt'): + self.cri_ldl = build_loss(train_opt['ldl_opt']).to(self.device) + else: + self.cri_ldl = None + if train_opt.get('perceptual_opt'): self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) else: diff --git a/options/train/LDL/train_LDL_Real_x4.yml b/options/train/LDL/train_LDL_Real_x4.yml new file mode 100644 index 000000000..959af121f --- /dev/null +++ b/options/train/LDL/train_LDL_Real_x4.yml @@ -0,0 +1,182 @@ +# general settings +name: train_LDL_realworld_RRDB +model_type: RealESRGANModel +scale: 4 +num_gpu: 4 +manual_seed: 0 + +# ----------------- options for synthesizing training data in RealESRGANModel ----------------- # +# USM the ground-truth +l1_gt_usm: True +percep_gt_usm: True +gan_gt_usm: False + +# the first degradation process +resize_prob: [0.2, 0.7, 0.1] # up, down, keep +resize_range: [0.15, 1.5] +gaussian_noise_prob: 0.5 +noise_range: [1, 30] +poisson_scale_range: [0.05, 3] +gray_noise_prob: 0.4 +jpeg_range: [30, 95] + +# the second degradation process +second_blur_prob: 0.8 +resize_prob2: [0.3, 0.4, 0.3] # up, down, keep +resize_range2: [0.3, 1.2] +gaussian_noise_prob2: 0.5 +noise_range2: [1, 25] +poisson_scale_range2: [0.05, 2.5] +gray_noise_prob2: 0.4 +jpeg_range2: [30, 95] + +gt_size: 256 +queue_size: 180 + +# dataset and data loader settings +datasets: + train: + name: DF2K+OST + type: RealESRGANDataset + dataroot_gt: datasets/DF2K + meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt + io_backend: + type: disk + + blur_kernel_size: 21 + kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob: 0.1 + blur_sigma: [0.2, 3] + betag_range: [0.5, 4] + betap_range: [1, 2] + + blur_kernel_size2: 21 + kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob2: 0.1 + blur_sigma2: [0.2, 1.5] + betag_range2: [0.5, 4] + betap_range2: [1, 2] + + final_sinc_prob: 0.8 + + gt_size: 256 + use_hflip: True + use_rot: False + + # data loader + use_shuffle: true + num_worker_per_gpu: 4 + batch_size_per_gpu: 4 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + val: + name: RealWorld38 + type: SingleImageDataset + dataroot_lq: datasets/RealWorld38/LR + io_backend: + type: disk + +# network structures +network_g: + type: RRDBNet + num_in_ch: 3 + num_out_ch: 3 + num_feat: 64 + num_block: 23 + num_grow_ch: 32 + + +network_d: + type: UNetDiscriminatorSN + num_in_ch: 3 + num_feat: 64 + skip_connection: True + +# path +path: + # use the pre-trained Real-ESRNet model + pretrain_network_g: experiments/pretrained_models/RealESRGAN/RealESRNet_x4plus.pth + param_key_g: params_ema + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + optim_d: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [400000] + gamma: 0.5 + + total_iter: 400000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: !!float 1e-2 + reduction: mean + ldl_opt: + type: L1Loss + loss_weight: !!float 1.0 + reduction: mean + # perceptual loss (content and style losses) + perceptual_opt: + type: PerceptualLoss + layer_weights: + # before relu + 'conv1_2': 0.1 + 'conv2_2': 0.1 + 'conv3_4': 1 + 'conv4_4': 1 + 'conv5_4': 1 + vgg_type: vgg19 + use_input_norm: true + perceptual_weight: !!float 1.0 + style_weight: 0 + range_norm: false + criterion: l1 + # gan loss + gan_opt: + type: GANLoss + gan_type: vanilla + real_label_val: 1.0 + fake_label_val: 0.0 + loss_weight: !!float 1e-1 + + net_d_iters: 1 + net_d_init_iters: 0 + +# Uncomment these for validation +# validation settings +val: + val_freq: !!float 5e3 + save_img: True + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500