forked from XPixelGroup/BasicSR
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Making PLKSR stable for real-world SISR dslisleedh/PLKSR#4 https://github.com/neosr-project/neosr/blob/master/neosr/archs/realplksr_arch.py
- Loading branch information
1 parent
f241e8f
commit a508c15
Showing
4 changed files
with
497 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
# type: ignore # noqa: PGH003 | ||
from functools import partial | ||
|
||
import torch | ||
from torch import nn | ||
from torch.nn.init import trunc_normal_ | ||
|
||
from neosr.archs.arch_util import DySample, net_opt | ||
from neosr.utils.registry import ARCH_REGISTRY | ||
|
||
upscale, __ = net_opt() | ||
|
||
|
||
class DCCM(nn.Sequential): | ||
"""Doubled Convolutional Channel Mixer""" | ||
|
||
def __init__(self, dim: int): | ||
super().__init__( | ||
nn.Conv2d(dim, dim * 2, 3, 1, 1), | ||
nn.Mish(), | ||
nn.Conv2d(dim * 2, dim, 3, 1, 1), | ||
) | ||
trunc_normal_(self[-1].weight, std=0.02) | ||
|
||
|
||
class PLKConv2d(nn.Module): | ||
"""Partial Large Kernel Convolutional Layer""" | ||
|
||
def __init__(self, dim: int, kernel_size: int): | ||
super().__init__() | ||
self.conv = nn.Conv2d(dim, dim, kernel_size, 1, kernel_size // 2) | ||
trunc_normal_(self.conv.weight, std=0.02) | ||
self.idx = dim | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
if self.training: | ||
x1, x2 = torch.split(x, [self.idx, x.size(1) - self.idx], dim=1) | ||
x1 = self.conv(x1) | ||
return torch.cat([x1, x2], dim=1) | ||
x[:, : self.idx] = self.conv(x[:, : self.idx]) | ||
return x | ||
|
||
|
||
class EA(nn.Module): | ||
"""Element-wise Attention""" | ||
|
||
def __init__(self, dim: int): | ||
super().__init__() | ||
self.f = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1), nn.Sigmoid()) | ||
trunc_normal_(self.f[0].weight, std=0.02) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
return x * self.f(x) | ||
|
||
|
||
class PLKBlock(nn.Module): | ||
def __init__( | ||
self, | ||
dim: int, | ||
kernel_size: int, | ||
split_ratio: float, | ||
norm_groups: int, | ||
use_ea: bool = True, | ||
): | ||
super().__init__() | ||
|
||
# Local Texture | ||
self.channel_mixer = DCCM(dim) | ||
|
||
# Long-range Dependency | ||
pdim = int(dim * split_ratio) | ||
|
||
# Conv Layer | ||
self.lk = PLKConv2d(pdim, kernel_size) | ||
|
||
# Instance-dependent modulation | ||
if use_ea: | ||
self.attn = EA(dim) | ||
else: | ||
self.attn = nn.Identity() | ||
|
||
# Refinement | ||
self.refine = nn.Conv2d(dim, dim, 1, 1, 0) | ||
trunc_normal_(self.refine.weight, std=0.02) | ||
|
||
# Group Normalization | ||
self.norm = nn.GroupNorm(norm_groups, dim) | ||
nn.init.constant_(self.norm.bias, 0) | ||
nn.init.constant_(self.norm.weight, 1.0) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
x_skip = x | ||
x = self.channel_mixer(x) | ||
x = self.lk(x) | ||
x = self.attn(x) | ||
x = self.refine(x) | ||
x = self.norm(x) | ||
|
||
return x + x_skip | ||
|
||
|
||
@ARCH_REGISTRY.register() | ||
class realplksr(nn.Module): | ||
"""Partial Large Kernel CNNs for Efficient Super-Resolution: | ||
https://arxiv.org/abs/2404.11848 | ||
""" | ||
|
||
def __init__( | ||
self, | ||
in_ch: int = 3, | ||
out_ch: int = 3, | ||
dim: int = 64, | ||
n_blocks: int = 28, | ||
upscaling_factor: int = upscale, | ||
kernel_size: int = 17, | ||
split_ratio: float = 0.25, | ||
use_ea: bool = True, | ||
norm_groups: int = 4, | ||
dropout: float = 0, | ||
dysample: bool = False, | ||
**kwargs, | ||
): | ||
super().__init__() | ||
|
||
self.upscale = upscaling_factor | ||
self.dysample = dysample | ||
if not self.training: | ||
dropout = 0 | ||
|
||
self.feats = nn.Sequential( | ||
*[nn.Conv2d(in_ch, dim, 3, 1, 1)] | ||
+ [ | ||
PLKBlock(dim, kernel_size, split_ratio, norm_groups, use_ea) | ||
for _ in range(n_blocks) | ||
] | ||
+ [nn.Dropout2d(dropout)] | ||
+ [nn.Conv2d(dim, out_ch * upscaling_factor**2, 3, 1, 1)] | ||
) | ||
trunc_normal_(self.feats[0].weight, std=0.02) | ||
trunc_normal_(self.feats[-1].weight, std=0.02) | ||
|
||
self.repeat_op = partial( | ||
torch.repeat_interleave, repeats=upscaling_factor**2, dim=1 | ||
) | ||
|
||
if dysample and upscaling_factor != 1: | ||
groups = out_ch if upscaling_factor % 2 != 0 else 4 | ||
self.to_img = DySample( | ||
in_ch * upscaling_factor**2, | ||
out_ch, | ||
upscaling_factor, | ||
groups=groups, | ||
end_convolution=True if upscaling_factor != 1 else False, | ||
) | ||
else: | ||
self.to_img = nn.PixelShuffle(upscaling_factor) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
x = self.feats(x) + self.repeat_op(x) | ||
if not self.dysample or (self.dysample and self.upscale != 1): | ||
x = self.to_img(x) | ||
return x | ||
|
||
|
||
@ARCH_REGISTRY.register() | ||
def realplksr_s(**kwargs): | ||
return realplksr(n_blocks=12, kernel_size=13, use_ea=False, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Results will be saved to neosr/experiments/results/ | ||
|
||
name = "test_realplksr" | ||
model_type = "image" | ||
scale = 4 | ||
#use_amp = true | ||
#compile = true | ||
|
||
[datasets.test_1] | ||
name = "val_1" | ||
type = "single" | ||
dataroot_lq = 'C:\datasets\val\' | ||
[val] | ||
#tile = 200 | ||
|
||
[network_g] | ||
type = "realplksr" | ||
#type = "realplksr_s" | ||
#dysample = true | ||
|
||
[path] | ||
pretrain_network_g = 'C:\model.pth' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
|
||
name = "train_realplksr" | ||
model_type = "image" | ||
scale = 4 | ||
use_amp = true | ||
bfloat16 = true | ||
fast_matmul = true | ||
#compile = true | ||
#manual_seed = 1024 | ||
|
||
[datasets.train] | ||
type = "paired" | ||
dataroot_gt = 'C:\datasets\gt\' | ||
dataroot_lq = 'C:\datasets\lq\' | ||
patch_size = 64 | ||
batch_size = 8 | ||
#accumulate = 1 | ||
#augmentation = [ "none", "mixup", "cutmix", "resizemix" ] # [ "cutblur" ] | ||
#aug_prob = [ 0.5, 0.1, 0.1, 0.1 ] # [ 0.7 ] | ||
|
||
[datasets.val] | ||
name = "val" | ||
type = "paired" | ||
dataroot_gt = 'C:\datasets\val\gt\' | ||
dataroot_lq = 'C:\datasets\val\lq\' | ||
[val] | ||
val_freq = 1000 | ||
#tile = 200 | ||
#[val.metrics.psnr] | ||
#type = "calculate_psnr" | ||
#[val.metrics.ssim] | ||
#type = "calculate_ssim" | ||
#[val.metrics.dists] | ||
#type = "calculate_dists" | ||
#better = "lower" | ||
|
||
[path] | ||
#pretrain_network_g = 'experiments\pretrain_g.pth' | ||
#pretrain_network_d = 'experiments\pretrain_d.pth' | ||
|
||
[network_g] | ||
type = "realplksr" | ||
#type = "realplksr_tiny" | ||
#dysample = true | ||
|
||
[network_d] | ||
type = "metagan" | ||
|
||
[train] | ||
ema = 0.999 | ||
wavelet_guided = true | ||
wavelet_init = 80000 | ||
#sam = "fsam" | ||
#sam_init = 1000 | ||
#eco = true | ||
#eco_init = 15000 | ||
#match_lq_colors = true | ||
|
||
[train.optim_g] | ||
type = "adan_sf" | ||
lr = 5e-4 | ||
betas = [ 0.98, 0.92, 0.99 ] | ||
weight_decay = 0.01 | ||
schedule_free = true | ||
warmup_steps = 1600 | ||
|
||
[train.optim_d] | ||
type = "adan_sf" | ||
lr = 1e-4 | ||
betas = [ 0.98, 0.92, 0.99 ] | ||
weight_decay = 0.01 | ||
schedule_free = true | ||
warmup_steps = 600 | ||
|
||
# losses | ||
[train.mssim_opt] | ||
type = "mssim_loss" | ||
loss_weight = 1.0 | ||
|
||
[train.consistency_opt] | ||
type = "consistency_loss" | ||
loss_weight = 1.0 | ||
|
||
[train.ldl_opt] | ||
type = "ldl_loss" | ||
loss_weight = 1.0 | ||
|
||
[train.fdl_opt] | ||
type = "fdl_loss" | ||
model = "vgg" # "resnet", "effnet", "inception" | ||
loss_weight = 0.5 | ||
|
||
[train.gan_opt] | ||
type = "gan_loss" | ||
gan_type = "bce" | ||
loss_weight = 0.3 | ||
|
||
#[train.msswd_opt] | ||
#type = "msswd_loss" | ||
#loss_weight = 1.0 | ||
|
||
#[train.perceptual_opt] | ||
#type = "vgg_perceptual_loss" | ||
#loss_weight = 0.5 | ||
#criterion = "huber" | ||
##patchloss = true | ||
##ipk = true | ||
##patch_weight = 1.0 | ||
|
||
#[train.dists_opt] | ||
#type = "dists_loss" | ||
#loss_weight = 0.5 | ||
|
||
#[train.ff_opt] | ||
#type = "ff_loss" | ||
#loss_weight = 0.35 | ||
|
||
#[train.ncc_opt] | ||
#type = "ncc_loss" | ||
#loss_weight = 1.0 | ||
|
||
#[train.kl_opt] | ||
#type = "kl_loss" | ||
#loss_weight = 1.0 | ||
|
||
[logger] | ||
total_iter = 1000000 | ||
save_checkpoint_freq = 1000 | ||
use_tb_logger = true | ||
#save_tb_img = true | ||
#print_freq = 100 |
Oops, something went wrong.