-
Notifications
You must be signed in to change notification settings - Fork 0
/
losses.py
202 lines (156 loc) · 7.58 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils
class CLIPLoss(nn.Module):
    """Symmetric image/text contrastive (CLIP-style) loss.

    Every local image embedding is scored against all text embeddings gathered
    from every rank, and every local text embedding against all gathered image
    embeddings. The positive for local row ``i`` is global column
    ``rank * local_batch_size + i``.
    """

    def __init__(self):
        super().__init__()
        # Target indices are cached and rebuilt only when the local batch size changes.
        self.labels = None
        self.last_local_batch_size = None

    def forward(self, outputs):
        """Compute the CLIP loss and image->text top-1 accuracy.

        Args:
            outputs: dict providing 'image_embed', 'text_embed' and 'logit_scale'.

        Returns:
            dict with 'loss' and 'clip_loss' (the same tensor) and 'clip_acc'
            (percentage of local images whose best-scoring text is their pair).
        """
        img = outputs['image_embed']
        txt = outputs['text_embed']
        scale = outputs['logit_scale']
        n_local = img.size(0)

        if n_local != self.last_local_batch_size:
            offset = n_local * utils.get_rank()
            self.labels = offset + torch.arange(n_local, device=img.device)
            self.last_local_batch_size = n_local

        # L2-normalise so the dot products below are cosine similarities.
        img = F.normalize(img, dim=-1, p=2)
        txt = F.normalize(txt, dim=-1, p=2)

        # Collect embeddings from all ranks. NOTE(review): presumably this variant
        # does not propagate gradients (cf. all_gather_batch_with_grad) -- confirm in utils.
        img_all, txt_all = utils.all_gather_batch([img, txt])

        logits_i2t = scale * img @ txt_all.t()
        logits_t2i = scale * txt @ img_all.t()

        targets = self.labels
        loss_i2t = F.cross_entropy(logits_i2t, targets)
        loss_t2i = F.cross_entropy(logits_t2i, targets)
        clip_loss = (loss_i2t + loss_t2i) / 2

        with torch.no_grad():
            hits = torch.argmax(logits_i2t, dim=-1).eq(targets).sum()
            acc = 100 * hits / n_local

        return {'loss': clip_loss, 'clip_loss': clip_loss, 'clip_acc': acc}
class SIMCLRLoss(nn.Module):
    """SimCLR (NT-Xent) loss from https://arxiv.org/abs/2002.05709.

    ``forward`` reads the two augmented views as separate tensors,
    ``outputs['aug1_embed']`` and ``outputs['aug2_embed']``. Row ``i`` of one
    view is the positive for row ``i`` of the other. Both views are
    L2-normalised and gathered across ranks with gradient. For each query, the
    candidates are every gathered row of the other view plus every gathered row
    of its own view except itself.

    Args:
        temperature (float): divisor applied to the cosine-similarity logits.
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        self.tau = temperature
        # Cached per local batch size: target indices and the self-similarity mask.
        self.labels = None
        self.masks = None
        self.last_local_batch_size = None

    def forward(self, outputs):
        """Return the symmetric SimCLR loss and view-a top-1 accuracy (percent).

        Returns:
            dict with 'loss' and 'ssl_loss' (the same tensor) and 'ssl_acc'.
        """
        view_a = outputs['aug1_embed']
        view_b = outputs['aug2_embed']
        view_a = F.normalize(view_a, dim=-1, p=2)
        view_b = F.normalize(view_b, dim=-1, p=2)
        n_local = view_a.size(0)

        all_a, all_b = utils.all_gather_batch_with_grad([view_a, view_b])

        if n_local != self.last_local_batch_size:
            self.labels = n_local * utils.get_rank() + torch.arange(
                n_local, device=view_a.device
            )
            n_global = n_local * utils.get_world_size()
            # Subtracting 1e9 at each row's own label column removes its
            # similarity with itself from the same-view candidates.
            self.masks = F.one_hot(self.labels, n_global) * 1e9
            self.last_local_batch_size = n_local

        def _scores(query, keys):
            # Temperature-scaled similarity of each query row to every key row.
            return torch.matmul(query, keys.transpose(0, 1)) / self.tau

        same_a = _scores(view_a, all_a) - self.masks
        same_b = _scores(view_b, all_b) - self.masks
        cross_ab = _scores(view_a, all_b)
        cross_ba = _scores(view_b, all_a)

        # Positives live in the first (cross-view) half of each concatenation.
        logits_a = torch.cat([cross_ab, same_a], dim=1)
        logits_b = torch.cat([cross_ba, same_b], dim=1)
        loss_a = F.cross_entropy(logits_a, self.labels)
        loss_b = F.cross_entropy(logits_b, self.labels)
        loss = (loss_a + loss_b) / 2  # average the two directions

        with torch.no_grad():
            hits = logits_a.argmax(dim=-1).eq(self.labels).sum()
            acc = 100 * hits / n_local

        return {'loss': loss, 'ssl_loss': loss, 'ssl_acc': acc}
class SLIPLoss(nn.Module):
    """Weighted sum of a CLIP loss and a self-supervised (SSL) loss.

    Args:
        ssl_loss: callable taking the model ``outputs`` dict and returning a
            dict with at least 'ssl_loss' and 'ssl_acc' (for example a
            ``SIMCLRLoss`` instance).
        ssl_scale: multiplier applied to the SSL term in the total loss.
    """

    def __init__(self, ssl_loss, ssl_scale):
        super().__init__()
        self.clip_loss = CLIPLoss()
        self.ssl_loss = ssl_loss
        self.ssl_scale = ssl_scale

    def forward(self, outputs):
        """Return the combined loss together with each component's loss and accuracy."""
        clip_out = self.clip_loss(outputs)
        clip_term, clip_acc = clip_out['clip_loss'], clip_out['clip_acc']

        ssl_out = self.ssl_loss(outputs)
        ssl_term, ssl_acc = ssl_out['ssl_loss'], ssl_out['ssl_acc']

        return dict(
            loss=clip_term + self.ssl_scale * ssl_term,
            clip_loss=clip_term,
            clip_acc=clip_acc,
            ssl_loss=ssl_term,
            ssl_acc=ssl_acc,
        )
class RILSLoss(nn.Module):
    """CLIP contrastive loss plus a masked-reconstruction distillation term.

    The contrastive part is computed the same way as ``CLIPLoss``. For the
    reconstruction part, the student's ``masked_pred`` tokens and the teacher's
    (detached) ``unmasked_feat`` tokens are compared by cosine similarity
    against every gathered text embedding. Those similarities become softmax
    distributions at separate temperatures, and the per-token KL divergence
    KL(teacher || student) is taken. Tokens are weighted by ``outputs['mask']``,
    and only samples whose image->text retrieval is correct contribute.

    Args:
        stu_tau: temperature for the student logits.
        tea_tau: temperature for the teacher logits.
        loss_weight: multiplier applied to the reconstruction loss.
    """

    def __init__(
        self,
        stu_tau=0.1,
        tea_tau=0.04,
        loss_weight=0.5,
    ):
        super().__init__()
        # Target indices are cached and rebuilt only when the local batch size changes.
        self.labels = None
        self.last_local_batch_size = None
        self.stu_tau = stu_tau
        self.tea_tau = tea_tau
        self.loss_weight = loss_weight

    def forward(self, outputs):
        """Compute the combined CLIP + reconstruction loss.

        Args:
            outputs: dict with 'image_embed', 'text_embed', 'logit_scale',
                'masked_pred', 'unmasked_feat' and 'mask'.
                'masked_pred' and 'unmasked_feat' are token sequences whose
                index 0 along dim 1 is the CLS token, which is dropped.
                'mask' must broadcast against the per-token loss; assumed
                (batch, tokens - 1), TODO confirm against the model.
                'mask' is read but NOT modified.

        Returns:
            dict with:
                'loss': clip_loss + recon_loss.
                'clip_loss'.
                'clip_acc': image->text top-1 accuracy, in percent.
                'recon_loss': already multiplied by ``loss_weight``.
        """
        image_embed = outputs['image_embed']
        text_embed = outputs['text_embed']
        logit_scale = outputs['logit_scale']
        local_batch_size = image_embed.size(0)

        if local_batch_size != self.last_local_batch_size:
            self.labels = local_batch_size * utils.get_rank() + torch.arange(
                local_batch_size, device=image_embed.device
            )
            self.last_local_batch_size = local_batch_size

        # normalized features
        image_embed = F.normalize(image_embed, dim=-1, p=2)
        text_embed = F.normalize(text_embed, dim=-1, p=2)

        # gather features from all GPUs
        image_embed_all, text_embed_all = \
            utils.all_gather_batch([image_embed, text_embed])

        # cosine similarity as logits
        logits_per_image = logit_scale * image_embed @ text_embed_all.t()
        logits_per_text = logit_scale * text_embed @ image_embed_all.t()

        clip_loss = (F.cross_entropy(logits_per_image, self.labels)
                     + F.cross_entropy(logits_per_text, self.labels)) / 2

        # Accuracy; `correct_retrieval` also gates the reconstruction loss below.
        with torch.no_grad():
            pred = torch.argmax(logits_per_image, dim=-1)
            correct_retrieval = pred.eq(self.labels)
            acc = 100 * correct_retrieval.sum() / local_batch_size

        masked_pred = outputs["masked_pred"]
        unmasked_feat = outputs["unmasked_feat"]
        mask = outputs["mask"]
        masked_pred = F.normalize(masked_pred[:, 1:], dim=-1, p=2)  # rm CLS token
        unmasked_feat = F.normalize(unmasked_feat[:, 1:], dim=-1, p=2)  # rm CLS token

        # Student logits: similarity of each predicted token to every gathered text.
        stu_logits = 1 / self.stu_tau * masked_pred @ text_embed_all.t()
        # Teacher logits are detached, so no gradient flows through the teacher branch.
        tea_logits = (1 / self.tea_tau * unmasked_feat @ text_embed_all.t()).detach()

        # KL(teacher || student) over the text axis, summed to one value per token.
        tea_prob = tea_logits.softmax(-1)
        recon_loss = (-tea_prob * stu_logits.log_softmax(-1)
                      + tea_prob * tea_logits.log_softmax(-1))
        recon_loss = recon_loss.sum(dim=-1)

        # Only correctly retrieved samples contribute. Build a new tensor instead
        # of `mask *= ...`, which mutated the caller's outputs["mask"] in place.
        mask = mask * correct_retrieval.unsqueeze(-1)
        recon_loss = (recon_loss * mask).sum() / (mask.sum() + 1e-6)
        recon_loss = recon_loss * self.loss_weight

        loss = clip_loss + recon_loss
        return {'loss': loss, 'clip_loss': clip_loss, 'clip_acc': acc, 'recon_loss': recon_loss}