
Commit

feat: 💥 add qalign
chaofengc committed Jan 31, 2024
1 parent bc7ecb3 commit 9cff9d3
Showing 5 changed files with 114 additions and 2 deletions.
18 changes: 17 additions & 1 deletion README.md
@@ -36,6 +36,12 @@ This is an image quality assessment toolbox with **pure python and pytorch**. We
---

### :triangular_flag_on_post: Updates/Changelog
- :boom: **Jan 31, 2024**. Add `qalign` for both NR and IAA. It is our most powerful unified metric, built on large vision-language models, and shows remarkable performance and robustness. Refer to [Q-Align](https://github.com/Q-Future/Q-Align) for more details. Use it with the following code (a note on preparing `input` follows this list):
```
from pyiqa import create_metric

qalign = create_metric('qalign').cuda()
quality_score = qalign(input, task_='quality')
aesthetic_score = qalign(input, task_='aesthetic')
```
- **Jan 19, 2024**. Add `wadiqam_fr` and `wadiqam_nr`. All implemented methods are usable now 🍻.
- **Dec 23, 2023**. Add `liqe` and `liqe_mix`. Thanks for the contribution from [Weixia Zhang](https://github.com/zwx8981) 🤗.
- **Oct 09, 2023**. Add datasets: [PIQ2023](https://github.com/DXOMARK-Research/PIQ2023), [GFIQA](http://database.mmsp-kn.de/gfiqa-20k-database.html). Add metric `topiq_nr-face`. We release example results on FFHQ [here](tests/ffhq_score_topiq_nr-face.csv) for reference.
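
For reference, `input` in the `qalign` snippet above is a `1x3xHxW` float tensor in `[0, 1]`, the usual pyiqa convention. A minimal sketch of preparing it, assuming a hypothetical local file `test.jpg`:
```
from PIL import Image
import torchvision.transforms.functional as TF

img = Image.open('test.jpg').convert('RGB')    # load and force 3 channels
input = TF.to_tensor(img).unsqueeze(0).cuda()  # 1x3xHxW float tensor in [0, 1]
```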
@@ -224,7 +230,7 @@ If you find our code helpful to your research, please consider using the following
}
```

Please also consider to cite our new work `TOPIQ` if it is useful to you:
Please also consider citing our works on image quality assessment if they are useful to you:
```
@article{chen2023topiq,
  title={TOPIQ: A Top-down Approach from Semantics to Distortions for Image Quality Assessment},
  author={Chen, Chaofeng and Mo, Jiadi and Hou, Jingwen and Wu, Haoning and Liao, Liang and Sun, Wenxiu and Yan, Qiong and Lin, Weisi},
  journal={arXiv preprint arXiv:2308.03060},
  year={2023}
}
```
```
@article{wu2023qalign,
  title={Q-Align: Teaching LMMs for Visual Scoring via Discrete Text-Defined Levels},
  author={Wu, Haoning and Zhang, Zicheng and Zhang, Weixia and Chen, Chaofeng and Li, Chunyi and Liao, Liang and Wang, Annan and Zhang, Erli and Sun, Wenxiu and Yan, Qiong and Min, Xiongkuo and Zhai, Guangtao and Lin, Weisi},
  journal={arXiv preprint arXiv:2312.17090},
  year={2023},
  institution={Nanyang Technological University and Shanghai Jiao Tong University and Sensetime Research},
  note={Equal Contribution by Wu, Haoning and Zhang, Zicheng. Project Lead by Wu, Haoning. Corresponding Authors: Zhai, Guangtao and Lin, Weisi.}
}
```

## :heart: Acknowledgement

65 changes: 65 additions & 0 deletions pyiqa/archs/qalign_arch.py
@@ -0,0 +1,65 @@
r"""Q-Align: All-in-one Foundation Model for visual scoring.
Reference:
@article{wu2023qalign,
title={Q-Align: Teaching LMMs for Visual Scoring via Discrete Text-Defined Levels},
author={Wu, Haoning and Zhang, Zicheng and Zhang, Weixia and Chen, Chaofeng and Li, Chunyi and Liao, Liang and Wang, Annan and Zhang, Erli and Sun, Wenxiu and Yan, Qiong and Min, Xiongkuo and Zhai, Guangtai and Lin, Weisi},
journal={arXiv preprint arXiv:2312.17090},
year={2023},
institution={Nanyang Technological University and Shanghai Jiao Tong University and Sensetime Research},
note={Equal Contribution by Wu, Haoning and Zhang, Zicheng. Project Lead by Wu, Haoning. Corresponding Authors: Zhai, Guangtai and Lin, Weisi.}
}
Reference url: https://github.com/Q-Future/Q-Align
"""
import torch
from torch import nn
import torchvision.transforms.functional as F
from PIL import Image
from transformers import AutoModelForCausalLM, CLIPImageProcessor

from pyiqa.utils.registry import ARCH_REGISTRY

from .constants import OPENAI_CLIP_MEAN

def expand2square(pil_img):
    """Pad a PIL image to a centered square canvas, filling the border with
    the OpenAI CLIP mean color (which normalizes to zero after CLIP preprocessing)."""
    background_color = tuple(int(x * 255) for x in OPENAI_CLIP_MEAN)
    width, height = pil_img.size
    maxwh = max(width, height)
    result = Image.new(pil_img.mode, (maxwh, maxwh), background_color)
    result.paste(pil_img, ((maxwh - width) // 2, (maxwh - height) // 2))
    return result


@ARCH_REGISTRY.register()
class QAlign(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        # Load the pretrained Q-Align (one-align) model from the HuggingFace hub.
        self.model = AutoModelForCausalLM.from_pretrained(
            "q-future/one-align",
            trust_remote_code=True,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.image_processor = CLIPImageProcessor.from_pretrained("q-future/one-align")

    def preprocess(self, x):
        assert x.shape[0] == 1, "Currently, only batch size 1 is supported."
        images = F.to_pil_image(x[0])
        images = expand2square(images)
        image_tensor = self.image_processor.preprocess(images, return_tensors="pt")["pixel_values"].half()
        return image_tensor.to(x.device)

    def forward(self, x, task_="quality", input_="image"):
        """
        Args:
            x: image tensor of shape (1, 3, H, W), values in [0, 1].
            task_ (str): one of [quality, aesthetic].
            input_ (str): input type; only "image" is supported for now.
        """
        if input_ == "image":
            image_tensor = self.preprocess(x)
            score = self.model.score(images=None, image_tensor=image_tensor, task_=task_, input_=input_)
        else:
            raise NotImplementedError(f"Input type {input_} is not supported yet.")

        return score
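
A minimal sketch of exercising this arch directly (assumes a CUDA device; the `q-future/one-align` weights are downloaded on first use):
```
import torch
from pyiqa.archs.qalign_arch import QAlign

model = QAlign()
x = torch.rand(1, 3, 512, 512).cuda()  # batch size must be 1 (see preprocess)
with torch.no_grad():
    print(model(x, task_='quality'))   # scalar quality score
```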
6 changes: 6 additions & 0 deletions pyiqa/default_model_configs.py
@@ -511,4 +511,10 @@
        },
        'metric_mode': 'NR',
    },
    'qalign': {
        'metric_opts': {
            'type': 'QAlign',
        },
        'metric_mode': 'NR',
    }
})
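
With this entry registered, the new metric is discoverable through the standard pyiqa interface (a quick sanity check, assuming the package is installed with the updated requirements):
```
import pyiqa

print('qalign' in pyiqa.list_models())  # True: the config entry above exposes the metric
qalign = pyiqa.create_metric('qalign')  # builds the QAlign arch from 'metric_opts'
```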
21 changes: 21 additions & 0 deletions pyiqa/utils/img_util.py
@@ -14,6 +14,27 @@ def is_image_file(filename):
    return any(filename.lower().endswith(extension) for extension in Image.registered_extensions())


def imread2pil(img_source, rgb=False):
    """Read an image source into a PIL Image.
    Args:
        img_source (str, bytes, or PIL.Image): image file path, raw image bytes, or a PIL Image instance.
        rgb (bool): convert the image to RGB if True.
    """
    if isinstance(img_source, (bytes, bytearray)):
        img = Image.open(io.BytesIO(img_source))
    elif isinstance(img_source, str):
        assert is_image_file(img_source), f'{img_source} is not a valid image file.'
        img = Image.open(img_source)
    elif isinstance(img_source, Image.Image):
        img = img_source
    else:
        raise TypeError(f'Unsupported source type: {type(img_source)}')
    if rgb:
        img = img.convert('RGB')
    return img


def imread2tensor(img_source, rgb=False):
    """Read image to tensor.
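
A short usage sketch for the new `imread2pil` helper (hypothetical file paths; all three accepted source types return the same PIL Image):
```
from pyiqa.utils.img_util import imread2pil

img_from_path = imread2pil('test.jpg', rgb=True)  # from a file path
with open('test.jpg', 'rb') as f:
    img_from_bytes = imread2pil(f.read())         # from raw bytes
img_passthrough = imread2pil(img_from_path)       # a PIL Image is returned as-is
```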
6 changes: 5 additions & 1 deletion requirements.txt
@@ -18,4 +18,8 @@ yapf
einops
imgaug
openai-clip
facexlib
transformers>=4.36.1
accelerate
icecream
sentencepiece
