-
Notifications
You must be signed in to change notification settings - Fork 1
/
clip_score.py
62 lines (45 loc) · 2.06 KB
/
clip_score.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import torch.nn as nn
from CLIP.clip import clip_feature_surgery
from torch.nn import functional as F
from torchvision.transforms import Compose, Resize, InterpolationMode
# Shared preprocessing step: antialiased bicubic resize of an image batch to
# 224x224. Both loss helpers in this file run their inputs through it before
# calling `encode_image`.
img_resize = Compose(
    [
        Resize(
            size=(224, 224),
            interpolation=InterpolationMode.BICUBIC,
            antialias=True,
        ),
    ]
)
def get_clip_score_from_feature(model, image, text_features, temp=100.):
    """Return a batch-averaged CLIP loss of `image` against `text_features`.

    The image batch is resized with `img_resize`, encoded with
    `model.encode_image`, L2-normalised along the last dimension and compared
    with `text_features` through `clip_feature_surgery`. The result is scaled
    by `temp`, soft-maxed over the last dimension and averaged over dim 1.
    The loss for each sample is one minus entry 0 of that averaged
    distribution.

    Args:
        model: Object exposing `encode_image(image)`.
        image: Image batch; after resizing its size is [b, 3, 224, 224]
            according to the original note (assumes 3-channel input —
            TODO confirm against callers).
        text_features: Text-side features forwarded unchanged to
            `clip_feature_surgery`.
        temp: Multiplier applied to the similarities before the softmax.

    Returns:
        The sum of the per-sample losses divided by the size of dim 0.
    """
    resized = img_resize(image)
    feats = model.encode_image(resized)
    feats = feats / feats.norm(dim=-1, keepdim=True)

    # Index 0 along dim 1 is discarded before scoring (presumably a
    # class/global token, leaving only patch tokens — TODO confirm).
    logits = temp * clip_feature_surgery(feats, text_features)[:, 1:, :]
    similarity = logits.softmax(dim=-1).mean(dim=1)

    # Entry 0 of the last dim is the one being maximised (presumably the
    # first text prompt — verify against how text_features is built).
    per_sample = 1. - similarity[:, 0]
    return per_sample.sum() / per_sample.shape[0]
class L_clip_from_feature(nn.Module):
    """Module wrapper around `get_clip_score_from_feature`.

    The module only stores the softmax temperature; it defines no parameters
    or sub-modules of its own. The model and text features are supplied on
    every `forward` call.
    """

    def __init__(self, temp=100.):
        """Store `temp`, the multiplier applied before the softmax."""
        super().__init__()
        self.temp = temp
        # Freeze whatever parameters are registered at this point (none are
        # defined by this class itself).
        self.requires_grad_(False)

    def forward(self, model, x, text_features):
        """Return `get_clip_score_from_feature(model, x, text_features, temp)`.

        Args:
            model: Object exposing `encode_image`.
            x: Image batch to score.
            text_features: Text-side features to compare against.
        """
        return get_clip_score_from_feature(model, x, text_features, self.temp)
def get_clip_score_MSE(res_model, pred, inp, weight):
    """Return a weighted sum of MSE losses between encoder features.

    `pred` and `inp` are concatenated along dim 1 so a single `img_resize`
    call resizes both, then split back apart and encoded separately with
    `res_model.encode_image`. Element `[1]` of each encoder output is indexed
    once per weight; the MSE of each matching pair is scaled by that weight
    and accumulated.

    Args:
        res_model: Object exposing `encode_image`; element `[1]` of its
            output must be indexable by position (presumably a list of
            intermediate feature maps — TODO confirm).
        pred: Predicted image batch. The channel split below assumes it has
            exactly 3 channels.
        inp: Reference image batch with matching batch and spatial sizes.
        weight: Sequence of per-feature weights. Its length decides how many
            feature pairs are compared.

    Returns:
        The accumulated weighted loss, or the integer 0 when `weight` is
        empty.
    """
    resized = img_resize(torch.cat([pred, inp], dim=1))
    # First 3 channels belong to `pred`, the remainder to `inp`.
    pred_out = res_model.encode_image(resized[:, :3, :, :])
    inp_out = res_model.encode_image(resized[:, 3:, :, :])

    # Fold starting from 0, adding terms left to right.
    return sum(
        (
            w * F.mse_loss(pred_out[1][idx], inp_out[1][idx])
            for idx, w in enumerate(weight)
        ),
        0,
    )
class L_clip_MSE(nn.Module):
    """Module wrapper around `get_clip_score_MSE`.

    The module keeps no state and defines no parameters of its own. The
    encoder, both image batches and the optional weights are passed to
    `forward`.
    """

    def __init__(self):
        """Initialise the module and freeze any registered parameters."""
        super().__init__()
        # No parameters are defined by this class itself, so this only has an
        # effect if some are registered by the time it runs.
        self.requires_grad_(False)

    def forward(self, model, pred, inp, weight=None):
        """Return the weighted feature-MSE loss between `pred` and `inp`.

        Args:
            model: Object exposing `encode_image`, forwarded to
                `get_clip_score_MSE`.
            pred: Predicted image batch.
            inp: Reference image batch.
            weight: Per-feature weights. When None, the five-entry default
                `[1.0, 1.0, 1.0, 1.0, 0.5]` is used.
        """
        weights = [1.0, 1.0, 1.0, 1.0, 0.5] if weight is None else weight
        return get_clip_score_MSE(model, pred, inp, weights)