-
Notifications
You must be signed in to change notification settings - Fork 249
/
byol_pytorch.py
284 lines (223 loc) · 8.97 KB
/
byol_pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
import copy
import random
from functools import wraps
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
from torchvision import transforms as T
# helper functions
def default(val, def_val):
    """Return ``val``, falling back to ``def_val`` when ``val`` is ``None``.

    Only ``None`` triggers the fallback; other falsy values such as ``0`` or
    ``False`` are returned unchanged.
    """
    if val is None:
        return def_val
    return val
def flatten(t):
    """Collapse every dimension after the first into a single one.

    The size of the leading dimension is preserved and all remaining
    elements are folded into the second dimension via ``reshape``.
    """
    leading = t.shape[0]
    return t.reshape(leading, -1)
def singleton(cache_key):
    """Build a method decorator that caches its result on the instance.

    The decorated method's return value is stored on ``self`` under the
    attribute named ``cache_key``. While that attribute holds anything other
    than ``None`` the stored value is returned and the method is not called.
    The attribute must already exist on the instance (it is read with a
    plain ``getattr``).
    """
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            cached = getattr(self, cache_key)
            if cached is None:
                # first call (or cache was reset to None): compute and store
                cached = fn(self, *args, **kwargs)
                setattr(self, cache_key, cached)
            return cached
        return wrapper
    return inner_fn
def get_module_device(module):
    """Return the device of the first parameter yielded by ``module``.

    Raises ``StopIteration`` when ``module.parameters()`` yields nothing.
    """
    first_param = next(module.parameters())
    return first_param.device
def set_requires_grad(model, val):
    """Set ``requires_grad`` to ``val`` on every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad = val
def MaybeSyncBatchnorm(is_distributed = None):
    """Pick the batch-norm class to use for the MLP heads.

    Args:
        is_distributed: explicit choice. When ``None`` the choice is derived
            from ``torch.distributed``: it is true only if a process group is
            initialized and its world size is greater than one.

    Returns:
        The class ``nn.SyncBatchNorm`` when distributed, otherwise the class
        ``nn.BatchNorm1d`` (a class, not an instance).
    """
    if is_distributed is None:
        # Probe torch.distributed only when the caller did not decide, and
        # check availability first so the probe cannot fail on builds that
        # ship without distributed support.
        is_distributed = (
            dist.is_available()
            and dist.is_initialized()
            and dist.get_world_size() > 1
        )
    return nn.SyncBatchNorm if is_distributed else nn.BatchNorm1d
# loss fn
def loss_fn(x, y):
    """Return ``2 - 2 * <x_hat, y_hat>`` computed along the last dimension.

    Both inputs are L2-normalized along their last dimension first, so the
    result has the inputs' shape with the last dimension reduced away.
    """
    x_unit = F.normalize(x, dim=-1, p=2)
    y_unit = F.normalize(y, dim=-1, p=2)
    similarity = (x_unit * y_unit).sum(dim=-1)
    return 2 - 2 * similarity
# augmentation utils
class RandomApply(nn.Module):
    """Module that applies ``fn`` to its input only some of the time.

    On each call a value is drawn with ``random.random()``. When the draw is
    greater than ``p`` the input is returned untouched; otherwise ``fn(x)``
    is returned.
    """

    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        skip = random.random() > self.p
        return x if skip else self.fn(x)
# exponential moving average
class EMA:
    """Exponential-moving-average updater with a fixed decay ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Blend ``old`` and ``new`` as ``old * beta + (1 - beta) * new``.

        When ``old`` is ``None`` there is nothing to blend with and ``new``
        is returned as is.
        """
        if old is None:
            return new
        decay = self.beta
        return old * decay + (1 - decay) * new
def update_moving_average(ema_updater, ma_model, current_model):
    """Move ``ma_model``'s parameters toward ``current_model``'s in place.

    Parameters of the two models are paired in iteration order, and each
    moving-average parameter's ``.data`` is replaced by
    ``ema_updater.update_average(old, new)``. Only what ``parameters()``
    yields is touched.
    """
    param_pairs = zip(current_model.parameters(), ma_model.parameters())
    for source_param, target_param in param_pairs:
        target_param.data = ema_updater.update_average(
            target_param.data, source_param.data
        )
# MLP class for projector and predictor
def MLP(dim, projection_size, hidden_size=4096, sync_batchnorm=None):
    """Build the two-layer MLP head: Linear -> batch norm -> ReLU -> Linear.

    Args:
        dim: input feature size of the first linear layer.
        projection_size: output feature size of the last linear layer.
        hidden_size: width of the hidden layer.
        sync_batchnorm: forwarded to ``MaybeSyncBatchnorm`` to choose the
            normalization class.
    """
    layers = [
        nn.Linear(dim, hidden_size),
        MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, projection_size),
    ]
    return nn.Sequential(*layers)
def SimSiamMLP(dim, projection_size, hidden_size=4096, sync_batchnorm=None):
    """Build the three-layer, bias-free MLP head.

    Each of the three linear layers has ``bias=False`` and is followed by a
    batch-norm layer chosen through ``MaybeSyncBatchnorm``. The first two
    norms are followed by ReLU; the last norm is created with
    ``affine=False`` and has no activation after it.

    Args:
        dim: input feature size of the first linear layer.
        projection_size: output feature size of the last linear layer.
        hidden_size: width of the two hidden layers.
        sync_batchnorm: forwarded to ``MaybeSyncBatchnorm``.
    """
    hidden_block_one = [
        nn.Linear(dim, hidden_size, bias=False),
        MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
        nn.ReLU(inplace=True),
    ]
    hidden_block_two = [
        nn.Linear(hidden_size, hidden_size, bias=False),
        MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
        nn.ReLU(inplace=True),
    ]
    output_block = [
        nn.Linear(hidden_size, projection_size, bias=False),
        MaybeSyncBatchnorm(sync_batchnorm)(projection_size, affine=False),
    ]
    return nn.Sequential(*hidden_block_one, *hidden_block_two, *output_block)
# a wrapper class for the base neural network
# will manage the interception of the hidden layer output
# and pipe it into the projector and predictor nets
class NetWrapper(nn.Module):
    """Wrap a base network, capture a hidden layer's output and project it.

    A forward hook on the selected layer stores that layer's flattened
    output; the projector MLP is created on first use, sized from the
    captured representation.

    Args:
        net: the base network to wrap.
        projection_size: output size of the projector.
        projection_hidden_size: hidden size of the projector.
        layer: which layer to read the representation from. A ``str`` is
            looked up in ``net.named_modules()``, an ``int`` indexes
            ``net.children()``, and ``-1`` means "use the output of ``net``
            itself" (no hook is installed).
        use_simsiam_mlp: build the projector with ``SimSiamMLP`` instead of
            ``MLP``.
        sync_batchnorm: forwarded to the projector constructor.
    """

    def __init__(self, net, projection_size, projection_hidden_size, layer = -2, use_simsiam_mlp = False, sync_batchnorm = None):
        super().__init__()
        self.net = net
        self.layer = layer

        # filled in lazily by ``_get_projector`` on the first projection
        self.projector = None
        self.projection_size = projection_size
        self.projection_hidden_size = projection_hidden_size

        self.use_simsiam_mlp = use_simsiam_mlp
        self.sync_batchnorm = sync_batchnorm

        # device -> flattened hidden output captured by the forward hook
        self.hidden = {}
        self.hook_registered = False

    def _find_layer(self):
        """Resolve ``self.layer`` to a submodule, or ``None`` if not found."""
        if isinstance(self.layer, str):
            return dict(self.net.named_modules()).get(self.layer)
        if isinstance(self.layer, int):
            return list(self.net.children())[self.layer]
        return None

    def _hook(self, _, input, output):
        """Forward hook: stash the flattened output, keyed by input device."""
        device = input[0].device
        self.hidden[device] = flatten(output)

    def _register_hook(self):
        """Attach ``_hook`` to the selected layer (asserts that it exists)."""
        layer = self._find_layer()
        assert layer is not None, f'hidden layer ({self.layer}) not found'
        layer.register_forward_hook(self._hook)
        self.hook_registered = True

    @singleton('projector')
    def _get_projector(self, hidden):
        """Create the projector from the representation's feature size.

        The result is cached on ``self.projector`` by the ``singleton``
        decorator, so the MLP is only built once.
        """
        _, dim = hidden.shape
        create_mlp_fn = SimSiamMLP if self.use_simsiam_mlp else MLP
        projector = create_mlp_fn(dim, self.projection_size, self.projection_hidden_size, sync_batchnorm = self.sync_batchnorm)
        # ``.to(tensor)`` matches the projector to ``hidden``'s device/dtype
        return projector.to(hidden)

    def get_representation(self, x):
        """Run ``net`` on ``x`` and return the selected representation."""
        if self.layer == -1:
            return self.net(x)

        if not self.hook_registered:
            self._register_hook()

        self.hidden.clear()
        _ = self.net(x)
        # ``.get`` rather than indexing: if the hook never fired for this
        # device the assert below reports it, instead of a bare KeyError
        # making that message unreachable.
        hidden = self.hidden.get(x.device)
        self.hidden.clear()

        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
        return hidden

    def forward(self, x, return_projection = True):
        """Return ``(projection, representation)``.

        When ``return_projection`` is false only the representation is
        returned and the projector is neither created nor run.
        """
        representation = self.get_representation(x)

        if not return_projection:
            return representation

        projector = self._get_projector(representation)
        projection = projector(representation)
        return projection, representation
# main class
class BYOL(nn.Module):
    """Self-supervised training wrapper around ``net``.

    Two augmented views of each batch go through an online encoder followed
    by a predictor, and through a target encoder. The loss compares each
    view's online prediction with the other view's target projection.

    Args:
        net: base network; wrapped in a ``NetWrapper``.
        image_size: side length used for the default random resized crop and
            for the mock batch sent through the model at construction time.
        hidden_layer: forwarded to ``NetWrapper`` as ``layer``.
        projection_size: output size of projector and predictor.
        projection_hidden_size: hidden size of projector and predictor.
        augment_fn: augmentation for the first view; defaults to the
            pipeline built in ``__init__``.
        augment_fn2: augmentation for the second view; defaults to the
            first view's augmentation.
        moving_average_decay: ``beta`` of the ``EMA`` updater used for the
            target encoder.
        use_momentum: when true the target encoder is a gradient-free deep
            copy of the online encoder; when false the online encoder is
            reused as target and the projector is built with ``SimSiamMLP``.
        sync_batchnorm: forwarded to the projector and predictor MLPs.
    """

    def __init__(
        self,
        net,
        image_size,
        hidden_layer = -2,
        projection_size = 256,
        projection_hidden_size = 4096,
        augment_fn = None,
        augment_fn2 = None,
        moving_average_decay = 0.99,
        use_momentum = True,
        sync_batchnorm = None
    ):
        super().__init__()
        self.net = net

        # default SimCLR augmentation
        DEFAULT_AUG = torch.nn.Sequential(
            RandomApply(
                T.ColorJitter(0.8, 0.8, 0.8, 0.2),
                p = 0.3
            ),
            T.RandomGrayscale(p=0.2),
            T.RandomHorizontalFlip(),
            RandomApply(
                T.GaussianBlur((3, 3), (1.0, 2.0)),
                p = 0.2
            ),
            T.RandomResizedCrop((image_size, image_size)),
            T.Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225])),
        )

        self.augment1 = default(augment_fn, DEFAULT_AUG)
        self.augment2 = default(augment_fn2, self.augment1)

        self.online_encoder = NetWrapper(
            net,
            projection_size,
            projection_hidden_size,
            layer = hidden_layer,
            use_simsiam_mlp = not use_momentum,
            sync_batchnorm = sync_batchnorm
        )

        self.use_momentum = use_momentum
        # created lazily by ``_get_target_encoder``
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        # honour the caller's sync_batchnorm choice here too, so predictor
        # and projector always use the same normalization class
        self.online_predictor = MLP(
            projection_size,
            projection_size,
            projection_hidden_size,
            sync_batchnorm = sync_batchnorm
        )

        # get device of network and make wrapper same device
        device = get_module_device(net)
        self.to(device)

        # send a mock image tensor to instantiate singleton parameters
        self.forward(torch.randn(2, 3, image_size, image_size, device=device))

    @singleton('target_encoder')
    def _get_target_encoder(self):
        """Deep-copy the online encoder with gradients disabled (cached)."""
        target_encoder = copy.deepcopy(self.online_encoder)
        set_requires_grad(target_encoder, False)
        return target_encoder

    def reset_moving_average(self):
        """Drop the target encoder so the next forward pass re-creates it."""
        del self.target_encoder
        self.target_encoder = None

    def update_moving_average(self):
        """Move the target encoder's parameters toward the online encoder's."""
        assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
        assert self.target_encoder is not None, 'target encoder has not been created yet'
        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)

    def forward(
        self,
        x,
        return_embedding = False,
        return_projection = True
    ):
        """Compute the training loss, or return embeddings.

        Args:
            x: input batch; it is passed through both augmentations.
            return_embedding: when true, skip augmentation and loss and
                return the online encoder's output for ``x`` directly.
            return_projection: forwarded to the online encoder when
                ``return_embedding`` is true.

        Returns:
            The mean of the summed two-view loss, or the online encoder's
            output when ``return_embedding`` is true.
        """
        assert not (self.training and x.shape[0] == 1), 'you must have greater than 1 sample when training, due to the batchnorm in the projection layer'

        if return_embedding:
            return self.online_encoder(x, return_projection = return_projection)

        image_one, image_two = self.augment1(x), self.augment2(x)

        # both views go through the encoders as one concatenated batch
        images = torch.cat((image_one, image_two), dim = 0)

        online_projections, _ = self.online_encoder(images)
        online_predictions = self.online_predictor(online_projections)

        online_pred_one, online_pred_two = online_predictions.chunk(2, dim = 0)

        with torch.no_grad():
            target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder

            target_projections, _ = target_encoder(images)
            # computed under no_grad and detached once here; the chunks
            # below need no further detach
            target_projections = target_projections.detach()

            target_proj_one, target_proj_two = target_projections.chunk(2, dim = 0)

        # each view's prediction is compared with the other view's target
        loss_one = loss_fn(online_pred_one, target_proj_two)
        loss_two = loss_fn(online_pred_two, target_proj_one)

        loss = loss_one + loss_two
        return loss.mean()