Skip to content

Commit

Permalink
RealPLKSR by. neosr-project
Browse files Browse the repository at this point in the history
  • Loading branch information
neosr-project authored and avan06 committed Dec 14, 2024
1 parent f241e8f commit a508c15
Show file tree
Hide file tree
Showing 4 changed files with 497 additions and 0 deletions.
167 changes: 167 additions & 0 deletions basicsr/archs/realplksr_arch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# type: ignore # noqa: PGH003
from functools import partial

import torch
from torch import nn
from torch.nn.init import trunc_normal_

from neosr.archs.arch_util import DySample, net_opt
from neosr.utils.registry import ARCH_REGISTRY

# Evaluated once at import time. The first value returned by `net_opt()` is
# used below as the default `upscaling_factor` of `realplksr`; the second
# value is not needed in this module and is discarded.
# NOTE(review): presumably `net_opt()` reads the active run configuration --
# confirm in `neosr.archs.arch_util`.
upscale, __ = net_opt()


class DCCM(nn.Sequential):
    """Doubled Convolutional Channel Mixer.

    A three-layer stack: a 3x3 convolution widening ``dim`` channels to
    ``2 * dim``, a Mish activation, and a 3x3 convolution narrowing back to
    ``dim``. Both convolutions use stride 1 and padding 1, so the spatial
    size of the input is preserved.

    Args:
        dim: Number of input and output channels.
    """

    def __init__(self, dim: int):
        hidden = dim * 2
        # Layers are created in stack order so parameter initialisation
        # consumes the random generator in the same sequence as the stack.
        expand = nn.Conv2d(dim, hidden, 3, 1, 1)
        activation = nn.Mish()
        project = nn.Conv2d(hidden, dim, 3, 1, 1)
        super().__init__(expand, activation, project)

        # Only the final (projecting) convolution is re-initialised.
        trunc_normal_(project.weight, std=0.02)


class PLKConv2d(nn.Module):
    """Partial Large Kernel Convolutional Layer.

    Applies a ``kernel_size`` x ``kernel_size`` convolution (stride 1,
    padding ``kernel_size // 2``) to the first ``dim`` channels of the input
    and passes the remaining channels through untouched.

    Args:
        dim: Number of leading channels that are convolved.
        kernel_size: Side length of the convolution kernel.
    """

    def __init__(self, dim: int, kernel_size: int):
        super().__init__()
        self.idx = dim
        self.conv = nn.Conv2d(dim, dim, kernel_size, 1, kernel_size // 2)
        trunc_normal_(self.conv.weight, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve the first ``self.idx`` channels of ``x``.

        In training mode a new tensor is returned and ``x`` is left intact.
        Outside training mode the result is written back into ``x`` in place
        and ``x`` itself is returned.
        """
        if not self.training:
            # Inference path: overwrite the leading channels of the input
            # rather than allocating a concatenated copy.
            x[:, : self.idx] = self.conv(x[:, : self.idx])
            return x

        passthrough = x.size(1) - self.idx
        head, tail = torch.split(x, [self.idx, passthrough], dim=1)
        return torch.cat([self.conv(head), tail], dim=1)


class EA(nn.Module):
    """Element-wise Attention.

    Computes a gate with a 3x3 convolution followed by a sigmoid and
    multiplies the input by that gate element by element.

    Args:
        dim: Number of input and output channels.
    """

    def __init__(self, dim: int):
        super().__init__()
        gate_conv = nn.Conv2d(dim, dim, 3, 1, 1)
        trunc_normal_(gate_conv.weight, std=0.02)
        self.f = nn.Sequential(gate_conv, nn.Sigmoid())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled by its sigmoid gate."""
        gate = self.f(x)
        return x * gate


class PLKBlock(nn.Module):
    """Residual block chaining the stages of one PLK layer.

    The input passes through, in order: the channel mixer (``DCCM``), the
    partial large-kernel convolution (``PLKConv2d``), an optional
    element-wise attention (``EA``), a 1x1 refinement convolution and a
    group normalisation. The block input is then added back.

    Args:
        dim: Number of channels entering and leaving the block.
        kernel_size: Kernel size handed to ``PLKConv2d``.
        split_ratio: Fraction of ``dim`` convolved by the large kernel;
            the channel count is ``int(dim * split_ratio)``.
        norm_groups: Number of groups for ``nn.GroupNorm``.
        use_ea: When false, the attention stage is an ``nn.Identity``.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int,
        split_ratio: float,
        norm_groups: int,
        use_ea: bool = True,
    ):
        super().__init__()

        # Local texture.
        self.channel_mixer = DCCM(dim)

        # Long-range dependency: only the first `partial_dim` channels go
        # through the large-kernel convolution.
        partial_dim = int(dim * split_ratio)
        self.lk = PLKConv2d(partial_dim, kernel_size)

        # Instance-dependent modulation (optional).
        self.attn = EA(dim) if use_ea else nn.Identity()

        # Point-wise refinement.
        self.refine = nn.Conv2d(dim, dim, 1, 1, 0)
        trunc_normal_(self.refine.weight, std=0.02)

        # Group normalisation with unit scale and zero shift.
        self.norm = nn.GroupNorm(norm_groups, dim)
        nn.init.constant_(self.norm.bias, 0)
        nn.init.constant_(self.norm.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the stages in sequence and add the residual connection."""
        out = x
        stages = (self.channel_mixer, self.lk, self.attn, self.refine, self.norm)
        for stage in stages:
            out = stage(out)
        return out + x


@ARCH_REGISTRY.register()
class realplksr(nn.Module):
    """Partial Large Kernel CNNs for Efficient Super-Resolution:
    https://arxiv.org/abs/2404.11848

    The network is a single ``nn.Sequential`` (``self.feats``): an input 3x3
    convolution, ``n_blocks`` ``PLKBlock`` layers, a ``Dropout2d`` and an
    output 3x3 convolution producing ``out_ch * upscaling_factor**2``
    channels. The input, with every channel repeated
    ``upscaling_factor**2`` times, is added to that output before the
    result is rearranged into the final image by ``self.to_img``.

    Args:
        in_ch: Number of input image channels.
        out_ch: Number of output image channels. The residual addition in
            ``forward`` needs ``in_ch * upscaling_factor**2`` to be
            broadcast-compatible with ``out_ch * upscaling_factor**2``.
        dim: Feature width used inside the blocks.
        n_blocks: Number of ``PLKBlock`` layers.
        upscaling_factor: Upscaling ratio; defaults to the module-level
            ``upscale`` value.
        kernel_size: Large-kernel size passed to every ``PLKBlock``.
        split_ratio: Fraction of channels convolved by the large kernel.
        use_ea: Whether the blocks use element-wise attention.
        norm_groups: Number of groups for the blocks' ``GroupNorm``.
        dropout: Probability for the ``Dropout2d`` placed before the last
            convolution.
        dysample: Use ``DySample`` instead of ``nn.PixelShuffle`` for the
            final upsampling (only when ``upscaling_factor != 1``).
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        in_ch: int = 3,
        out_ch: int = 3,
        dim: int = 64,
        n_blocks: int = 28,
        upscaling_factor: int = upscale,
        kernel_size: int = 17,
        split_ratio: float = 0.25,
        use_ea: bool = True,
        norm_groups: int = 4,
        dropout: float = 0,
        dysample: bool = False,
        **kwargs,
    ):
        super().__init__()

        self.upscale = upscaling_factor
        self.dysample = dysample
        # No inference-time override of `dropout` is needed here: a freshly
        # constructed nn.Module is always in training mode, and nn.Dropout2d
        # is already an identity once the model is switched to eval().

        # Channel count produced by `self.feats` and consumed by `to_img`.
        upsample_ch = out_ch * upscaling_factor**2

        layers = [nn.Conv2d(in_ch, dim, 3, 1, 1)]
        layers.extend(
            PLKBlock(dim, kernel_size, split_ratio, norm_groups, use_ea)
            for _ in range(n_blocks)
        )
        layers.append(nn.Dropout2d(dropout))
        layers.append(nn.Conv2d(dim, upsample_ch, 3, 1, 1))
        self.feats = nn.Sequential(*layers)
        trunc_normal_(self.feats[0].weight, std=0.02)
        trunc_normal_(self.feats[-1].weight, std=0.02)

        # Repeats each input channel so the input can be added to the
        # `upscaling_factor**2`-times wider feature output.
        self.repeat_op = partial(
            torch.repeat_interleave, repeats=upscaling_factor**2, dim=1
        )

        if dysample and upscaling_factor != 1:
            groups = out_ch if upscaling_factor % 2 != 0 else 4
            # The upsampler is sized from the tensor it actually receives
            # (`out_ch * r**2` channels), not from the input channel count.
            # NOTE(review): DySample's argument meaning is taken from the
            # call shape (in channels, out channels, scale) -- confirm in
            # `neosr.archs.arch_util`.
            self.to_img = DySample(
                upsample_ch,
                out_ch,
                upscaling_factor,
                groups=groups,
                # This branch is only reached when upscaling_factor != 1.
                end_convolution=True,
            )
        else:
            self.to_img = nn.PixelShuffle(upscaling_factor)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute features, add the repeated input and upsample."""
        x = self.feats(x) + self.repeat_op(x)
        # Upsampling is skipped only for `dysample=True` at scale 1, where
        # `to_img` is a PixelShuffle(1).
        if not self.dysample or self.upscale != 1:
            x = self.to_img(x)
        return x


@ARCH_REGISTRY.register()
def realplksr_s(**kwargs):
    """Build the small ``realplksr`` variant.

    Fixes ``n_blocks=12``, ``kernel_size=13`` and ``use_ea=False``; every
    other keyword argument is forwarded to ``realplksr`` unchanged. Passing
    one of the three fixed names again raises ``TypeError``.
    """
    small_preset = {"n_blocks": 12, "kernel_size": 13, "use_ea": False}
    return realplksr(**small_preset, **kwargs)
22 changes: 22 additions & 0 deletions options/test/RealPLKSR/test_realplksr.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Results will be saved to neosr/experiments/results/

name = "test_realplksr"
model_type = "image"
scale = 4
#use_amp = true
#compile = true

[datasets.test_1]
name = "val_1"
type = "single"
dataroot_lq = 'C:\datasets\val\'
[val]
#tile = 200

[network_g]
type = "realplksr"
#type = "realplksr_s"
#dysample = true

[path]
pretrain_network_g = 'C:\model.pth'
131 changes: 131 additions & 0 deletions options/train/RealPLKSR/train_realplksr.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@

name = "train_realplksr"
model_type = "image"
scale = 4
use_amp = true
bfloat16 = true
fast_matmul = true
#compile = true
#manual_seed = 1024

[datasets.train]
type = "paired"
dataroot_gt = 'C:\datasets\gt\'
dataroot_lq = 'C:\datasets\lq\'
patch_size = 64
batch_size = 8
#accumulate = 1
#augmentation = [ "none", "mixup", "cutmix", "resizemix" ] # [ "cutblur" ]
#aug_prob = [ 0.5, 0.1, 0.1, 0.1 ] # [ 0.7 ]

[datasets.val]
name = "val"
type = "paired"
dataroot_gt = 'C:\datasets\val\gt\'
dataroot_lq = 'C:\datasets\val\lq\'
[val]
val_freq = 1000
#tile = 200
#[val.metrics.psnr]
#type = "calculate_psnr"
#[val.metrics.ssim]
#type = "calculate_ssim"
#[val.metrics.dists]
#type = "calculate_dists"
#better = "lower"

[path]
#pretrain_network_g = 'experiments\pretrain_g.pth'
#pretrain_network_d = 'experiments\pretrain_d.pth'

[network_g]
type = "realplksr"
#type = "realplksr_s"
#dysample = true

[network_d]
type = "metagan"

[train]
ema = 0.999
wavelet_guided = true
wavelet_init = 80000
#sam = "fsam"
#sam_init = 1000
#eco = true
#eco_init = 15000
#match_lq_colors = true

[train.optim_g]
type = "adan_sf"
lr = 5e-4
betas = [ 0.98, 0.92, 0.99 ]
weight_decay = 0.01
schedule_free = true
warmup_steps = 1600

[train.optim_d]
type = "adan_sf"
lr = 1e-4
betas = [ 0.98, 0.92, 0.99 ]
weight_decay = 0.01
schedule_free = true
warmup_steps = 600

# losses
[train.mssim_opt]
type = "mssim_loss"
loss_weight = 1.0

[train.consistency_opt]
type = "consistency_loss"
loss_weight = 1.0

[train.ldl_opt]
type = "ldl_loss"
loss_weight = 1.0

[train.fdl_opt]
type = "fdl_loss"
model = "vgg" # "resnet", "effnet", "inception"
loss_weight = 0.5

[train.gan_opt]
type = "gan_loss"
gan_type = "bce"
loss_weight = 0.3

#[train.msswd_opt]
#type = "msswd_loss"
#loss_weight = 1.0

#[train.perceptual_opt]
#type = "vgg_perceptual_loss"
#loss_weight = 0.5
#criterion = "huber"
##patchloss = true
##ipk = true
##patch_weight = 1.0

#[train.dists_opt]
#type = "dists_loss"
#loss_weight = 0.5

#[train.ff_opt]
#type = "ff_loss"
#loss_weight = 0.35

#[train.ncc_opt]
#type = "ncc_loss"
#loss_weight = 1.0

#[train.kl_opt]
#type = "kl_loss"
#loss_weight = 1.0

[logger]
total_iter = 1000000
save_checkpoint_freq = 1000
use_tb_logger = true
#save_tb_img = true
#print_freq = 100
Loading

0 comments on commit a508c15

Please sign in to comment.