simple_inference.py
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # choose GPU if you are on a multi-GPU server
import json
import os

import click
import torch
from PIL import Image
from transformers import AutoProcessor, CLIPModel, logging

from MLP import MLP

logging.set_verbosity_error()

# This script will predict the aesthetic score for the given image file.
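# Example invocation (the file names below are illustrative placeholders):
#   python simple_inference.py --image photo.jpg --model aesthetic_head.pth
# Pass --raw to skip the mean/std rescaling of the output.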


class dotdict(dict):
    # Convenience wrapper: allows attribute-style access to dict keys (e.g. opts.image).
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


@click.command()
@click.option(
    "--image", help="Image file to evaluate", metavar="[FILE]", type=str, required=True
)
@click.option(
    "--model", help="Path to the trained MLP checkpoint", metavar="[FILE]", type=str, required=True
)
@click.option(
    "--clip",
    help="Hugging Face model used to embed images with CLIP",
    type=str,
    default="openai/clip-vit-large-patch14",
    show_default=True,
)
@click.option(
    "--device",
    help="Torch device type (defaults to cuda if available)",
    type=str,
    default="default",
    show_default=True,
)
@click.option(
    "--raw",
    help="Print the raw model output without de-normalization",
    is_flag=True,
)
def main(**kwargs):
    opts = dotdict(kwargs)

    # Use CUDA when available unless a device was given explicitly.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if opts.device != "default":
        device = opts.device

    # Load the CLIP backbone and its image preprocessor.
    clip_model = CLIPModel.from_pretrained(opts.clip).to(device)
    preprocess = AutoProcessor.from_pretrained(opts.clip)

    # The MLP head takes the CLIP image embedding as input
    # (projection_dim is 768 for CLIP ViT-L/14).
    dim = clip_model.projection_dim
    model = MLP(dim)
    sd = torch.load(opts.model, map_location=device)  # map_location avoids failures on CPU-only machines
    if "state_dict" in sd:
        sd = sd["state_dict"]
    model.load_state_dict(sd)
    model.to(device)
    model.eval()

    # Preprocess the image, embed it with CLIP, and score the embedding.
    pil_image = Image.open(opts.image)
    image = preprocess(images=pil_image, return_tensors="pt")["pixel_values"].to(device)
    with torch.inference_mode():
        image_features = clip_model.get_image_features(image)
        im_emb_arr = image_features.type(torch.float)
        prediction = model(im_emb_arr)

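    # Optionally load normalization stats saved next to the checkpoint. The JSON file
    # is assumed to hold the training-set mean and std of the scores,
    # e.g. {"mean": 5.0, "std": 1.5} (values here are only illustrative).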
    try:
        with open(os.path.splitext(opts.model)[0] + ".json", "rt") as f:
            y_stats = json.load(f)
    except Exception:
        y_stats = None

    print("Aesthetic score predicted by the model:")
    if y_stats is None or opts.raw:
        print(prediction.item())
    else:
        # De-normalize the prediction back to the original score scale.
        print(prediction.item() * float(y_stats["std"]) + float(y_stats["mean"]))


if __name__ == "__main__":
    main()