Skip to content

Commit

Permalink
feat: ✨ add wadiqam pretrained models
Browse files Browse the repository at this point in the history
  • Loading branch information
chaofengc committed Jan 19, 2024
1 parent 3041923 commit 53cd2db
Show file tree
Hide file tree
Showing 8 changed files with 164 additions and 26 deletions.
8 changes: 3 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,11 @@ This is a image quality assessment toolbox with **pure python and pytorch**. We
---

### :triangular_flag_on_post: Updates/Changelog
- **Jan 19, 2024**. Add `wadiqam_fr` and `wadiqam_nr`. All implemented methods are usable now 🍻.
- **Dec 23, 2023**. Add `liqe` and `liqe_mix`. Thanks for the contribution from [Weixia Zhang](https://github.com/zwx8981) 🤗.
- **Oct 09, 2023**. Add datasets: [PIQ2023](https://github.com/DXOMARK-Research/PIQ2023), [GFIQA](http://database.mmsp-kn.de/gfiqa-20k-database.html). Add metric `topiq_nr-face`. We release example results on FFHQ [here](tests/ffhq_score_topiq_nr-face.csv) for reference.
- **Aug 15, 2023**. Add `st-lpips` and `laion_aes`. Refer to official repo at [ShiftTolerant-LPIPS](https://github.com/abhijay9/ShiftTolerant-LPIPS) and [improved-aesthetic-predictor](https://github.com/christophschuhmann/improved-aesthetic-predictor)
- **Aug 05, 2023**. Add our work [TOPIQ](https://arxiv.org/abs/2308.03060) with remarkable performance on almost all benchmarks via efficient Resnet50 backbone. Use it with `topiq_fr, topiq_nr, topiq_iaa` for Full-Reference, No-Reference and Aesthetic assessment respectively.
- **March 30, 2023**. Add [URanker](https://github.com/RQ-Wu/UnderwaterRanker) for IQA of under water images.
- **March 29, 2023**. :rotating_light: Hot fix of NRQM & PI.
- **March 25, 2023**. Add TreS, HyperIQA, CNNIQA, CLIPIQA.
- [**More**](docs/history_changelog.md)

---
Expand Down Expand Up @@ -160,8 +158,8 @@ Basically, we use the largest existing datasets for training, and cross dataset

| Metric Type | Reproduced Models |
| ------------- | ----------------------------- |
| FR | |
| NR | `cnniqa`, `dbcnn`, `hyperiqa` |
| FR | `wadiqam_fr` |
| NR | `cnniqa`, `dbcnn`, `hyperiqa`, `wadiqam_nr` |
| Aesthetic IQA | `nima`, `nima-vgg16-ava` |

**Important Notes:**
Expand Down
3 changes: 3 additions & 0 deletions docs/history_changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# History of Changelog

- **March 30, 2023**. Add [URanker](https://github.com/RQ-Wu/UnderwaterRanker) for IQA of under water images.
- **March 29, 2023**. :rotating_light: Hot fix of NRQM & PI.
- **March 25, 2023**. Add TreS, HyperIQA, CNNIQA, CLIPIQA.
- **Sep 1, 2022**. 1) Add pretrained models for MANIQA and AHIQ. 2) Add dataset interface for pieapp and PIPAL.
- **June 3, 2022**. Add FID metric. See [clean-fid](https://github.com/GaParmar/clean-fid) for more details.
- **March 11, 2022**. Add pretrained DBCNN, NIMA, and official model of PieAPP, paq2piq.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,41 +1,44 @@
# general settings
name: debug001_WaDIQaM_FR_TID2013
name: 005_WaDIQaM_FR_TID2013
name: 005_WaDIQaM_FR_kadid
model_type: WaDIQaMModel
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 123

# dataset and data loader settings
datasets:
train:
name: tid2013
name: general_iqa_dataset
type: GeneralFRDataset
dataroot_target: ./datasets/tid2013/distorted_images
dataroot_ref: ./datasets/tid2013/reference_images
meta_info_file: ./datasets/meta_info/meta_info_TID2013Dataset.csv
dataroot_target: './datasets/kadid10k/images'
meta_info_file: './datasets/meta_info/meta_info_KADID10kDataset.csv'
split_file: './datasets/meta_info/kadid10k_seed123.pkl'
mos_range: [1, 5]
lower_better: false

# data loader
use_shuffle: true
num_worker_per_gpu: 12
batch_size_per_gpu: 64
dataset_enlarge_ratio: 1
prefetch_mode: cpu
num_prefetch_queue: 8
num_prefetch_queue: 128

val:
name: tid2013
name: general_iqa_dataset
type: GeneralFRDataset
dataroot_target: ./datasets/tid2013/distorted_images
dataroot_ref: ./datasets/tid2013/reference_images
meta_info_file: ./datasets/meta_info/meta_info_TID2013Dataset.csv

dataroot_target: './datasets/kadid10k/images'
meta_info_file: './datasets/meta_info/meta_info_KADID10kDataset.csv'
split_file: './datasets/meta_info/kadid10k_seed123.pkl'
mos_range: [1, 5]
lower_better: false

num_worker_per_gpu: 4
batch_size_per_gpu: 32

# network structures
network:
type: WaDIQaM
metric_mode: FR
metric_type: FR

# path
path:
Expand Down
104 changes: 104 additions & 0 deletions options/train/WaDIQaM/train_WaDIQaM_NR_koniq.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# general settings
name: 005_WaDIQaM_FR_koniq
model_type: WaDIQaMModel
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 123

# dataset and data loader settings
datasets:
train:
name: koniq10k
type: GeneralNRDataset
dataroot_target: ./datasets/koniq10k/512x384
meta_info_file: ./datasets/meta_info/meta_info_KonIQ10kDataset.csv
split_file: ./datasets/meta_info/koniq10k_official.pkl
mos_range: [0, 100]
lower_better: false
mos_normalize: true

# data loader
use_shuffle: true
num_worker_per_gpu: 12
batch_size_per_gpu: 64
dataset_enlarge_ratio: 1
prefetch_mode: cpu
num_prefetch_queue: 128

val:
name: koniq10k
type: GeneralNRDataset
dataroot_target: ./datasets/koniq10k/512x384
meta_info_file: ./datasets/meta_info/meta_info_KonIQ10kDataset.csv
split_file: ./datasets/meta_info/koniq10k_official.pkl
mos_range: [0, 100]
lower_better: false
mos_normalize: true

num_worker_per_gpu: 4
batch_size_per_gpu: 32

# network structures
network:
type: WaDIQaM
metric_type: NR
pretrained_model_path: ./experiments/005_WaDIQaM_FR_kadid/models/net_best.pth
load_feature_weight_only: True

# path
path:
pretrain_network_g: ~
strict_load_g: true
resume_state: ~

# training settings
train:
optim:
type: Adam
lr_basemodel: !!float 1e-4
lr_fc_layers: !!float 1e-4
lr: !!float 1e-4

scheduler:
type: MultiStepLR
milestones: [400, 800, 1200]
gamma: 1

total_iter: 40000 #
warmup_iter: -1 # no warm up

# losses
mos_loss_opt:
type: PLCCLoss
loss_weight: !!float 1.0

# validation settings
val:
val_freq: !!float 100
save_img: false
pbar: true

key_metric: srcc # if this metric improve, update all metrics. If not specified, each best metric results will be updated separately
metrics:
srcc:
type: calculate_srcc

plcc:
type: calculate_plcc

krcc:
type: calculate_krcc

# logging settings
logger:
print_freq: 20
save_checkpoint_freq: !!float 5e9
save_latest_freq: !!float 5e2
use_tb_logger: true
wandb:
project: ~
resume_id: ~

# dist training settings
dist_params:
backend: nccl
port: 29500
28 changes: 20 additions & 8 deletions pyiqa/archs/wadiqam_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,17 @@
import torch
import torch.nn as nn
from pyiqa.utils.registry import ARCH_REGISTRY
from pyiqa.archs.arch_util import load_pretrained_network

from typing import Union, List, cast


default_model_urls = {
'wadiqam_fr_kadid': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/WaDIQaM-kadid-f7541ea5.pth',
'wadiqam_nr_koniq': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/WaDIQaM-NR-koniq-aaffea29.pth',
}


def make_layers(cfg: List[Union[str, int]]) -> nn.Sequential:
layers: List[nn.Module] = []
in_channels = 3
Expand All @@ -38,7 +45,7 @@ def make_layers(cfg: List[Union[str, int]]) -> nn.Sequential:
class WaDIQaM(nn.Module):
"""WaDIQaM model.
Args:
metric_mode (String): Choose metric mode.
metric_type (String): Choose metric mode.
weighted_average (Boolean): Average the weight.
train_patch_num (int): Number of patch trained. Default: 32.
pretrained_model_path (String): The pretrained model path.
Expand All @@ -49,7 +56,9 @@ class WaDIQaM(nn.Module):

def __init__(
self,
metric_mode='FR',
metric_type='FR',
model_name='wadiqam_fr_kadid',
pretrained=True,
weighted_average=True,
train_patch_num=32,
pretrained_model_path=None,
Expand All @@ -63,8 +72,8 @@ def __init__(

self.train_patch_num = train_patch_num
self.patch_size = 32 # This cannot be changed due to network design
self.metric_mode = metric_mode
fc_in_channel = 512 * 3 if metric_mode == 'FR' else 512
self.metric_type = metric_type
fc_in_channel = 512 * 3 if metric_type == 'FR' else 512
self.eps = eps

self.fc_q = nn.Sequential(
Expand All @@ -86,19 +95,22 @@ def __init__(

if pretrained_model_path is not None:
self.load_pretrained_network(pretrained_model_path, load_feature_weight_only)
elif pretrained:
self.metric_type = model_name.split('_')[1].upper()
load_pretrained_network(self, default_model_urls[model_name], True, weight_keys='params')

def load_pretrained_network(self, model_path, load_feature_weight_only=False):
print(f'Loading pretrained model from {model_path}')
state_dict = torch.load(model_path, map_location=torch.device('cpu'))['state_dict']
state_dict = torch.load(model_path, map_location=torch.device('cpu'))['params']
if load_feature_weight_only:
print('Only load backbone feature net')
new_state_dict = {}
for k in state_dict.keys():
if 'features' in k:
new_state_dict[k] = state_dict[k]
self.net.load_state_dict(new_state_dict, strict=False)
self.load_state_dict(new_state_dict, strict=False)
else:
self.net.load_state_dict(state_dict, strict=True)
self.load_state_dict(state_dict, strict=True)

def _get_random_patches(self, x, y=None):
"""train with random crop patches"""
Expand Down Expand Up @@ -165,7 +177,7 @@ def forward(self, x, y=None):
x: An input tensor. Shape :math:`(N, C, H, W)`.
y: A reference tensor. Shape :math:`(N, C, H, W)`.
"""
if self.metric_mode == 'FR':
if self.metric_type == 'FR':
assert y is not None, 'Full reference metric requires reference input'
x_patches, y_patches = self.get_patches(x, y)
feat_img = self.extract_features(x_patches)
Expand Down
16 changes: 16 additions & 0 deletions pyiqa/default_model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,4 +495,20 @@
},
'metric_mode': 'NR',
},
'wadiqam_fr': {
'metric_opts': {
'type': 'WaDIQaM',
'metric_type': 'FR',
'model_name': 'wadiqam_fr_kadid',
},
'metric_mode': 'FR',
},
'wadiqam_nr': {
'metric_opts': {
'type': 'WaDIQaM',
'metric_type': 'NR',
'model_name': 'wadiqam_nr_koniq',
},
'metric_mode': 'NR',
},
})
1 change: 1 addition & 0 deletions tests/FR_benchmark_results.csv
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ lpips,0.9005/0.9233/0.7499,0.7672/0.869/0.6768,0.711/0.7151/0.5221,0.7529/0.7445
dists,0.9324/0.9296/0.7644,0.8392/0.9051/0.7283,0.7032/0.6648/0.4861,0.7538/0.7077/0.5212
pieapp,0.838/0.8968/0.7109,0.8577/0.9182/0.7491,0.6443/0.7971/0.6089,0.7195/0.8438/0.6571
ahiq,0.8234/0.8273/0.6168,0.8039/0.8967/0.7066,0.6772/0.6807/0.4842,0.7379/0.7075/0.5127
wadiqam_fr,0.9087/0.922/0.7461,0.9163/0.9308/0.7584,0.8221/0.8222/0.6245,0.8424/0.8264/0.628
topiq_fr,0.9589/0.9674/0.8379,0.9542/0.9759/0.8617,0.9044/0.9226/0.7554,0.9158/0.9165/0.7441
1 change: 1 addition & 0 deletions tests/NR_benchmark_results.csv
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pi,0.5201/0.4615/0.3139,0.4688/0.4573/0.3132,0.4627/0.3479/0.2398,,0.7353/0.7307
nima,0.4993/0.5071/0.348,0.7156/0.6662/0.4816,0.3324/0.321/0.2159,,0.5153/0.5201/0.3558,
paq2piq,0.7542/0.7188/0.5302,0.7062/0.643/0.4622,0.5776/0.4011/0.2838,,0.775/0.8289/0.6207,
cnniqa,0.6372/0.6089/0.4257,0.7934/0.7551/0.558,0.398/0.1769/0.117,,0.7272/0.7397/0.5263,
wadiqam_nr,0.6631/0.6675/0.4721,0.83/0.8046/0.6129,0.3517/0.1544/0.1002,,,
dbcnn,0.774/0.7562/0.5563,0.9197/0.9034/0.7338,0.5141/0.3855/0.2691,,0.8549/0.8473/0.639,
musiq-ava,0.6001/0.5954/0.4235,0.589/0.5273/0.3714,,,,
musiq-koniq,0.8295/0.7889/0.5986,0.8958/0.8654/0.6817,0.6814/0.575/0.4131,0.5128/0.4978/0.3437,0.8626/0.8676/0.6649,
Expand Down

0 comments on commit 53cd2db

Please sign in to comment.