-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
376 lines (317 loc) · 14 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
import argparse
import json
import math
import os
import random
import signal
import subprocess
import sys
import time
from PIL import Image, ImageOps, ImageFilter
from torch import nn, optim
import torch
import torchvision
import torchvision.transforms as transforms
import torch.distributed.nn
# Command-line interface; `main` parses it and the resulting attributes are
# read by `main_worker`, `adjust_learning_rate` and `Model`.
parser = argparse.ArgumentParser(description='Training')
parser.add_argument('data', type=Path, metavar='DIR',
                    help='path to dataset')
parser.add_argument('--workers', default=8, type=int, metavar='N',
                    help='number of data loader workers')
parser.add_argument('--epochs', default=100, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--stop-at-epoch', default=100, type=int, metavar='N',
                    help='interrupt training at the end of this epoch')
parser.add_argument('--batch-size', default=2048, type=int, metavar='N',
                    help='mini-batch size')
parser.add_argument('--learning-rate-weights', default=0.3, type=float, metavar='LR',
                    help='base learning rate for weights')
parser.add_argument('--weight-decay', default=1e-6, type=float, metavar='W',
                    help='weight decay')
parser.add_argument('--inv-coef', default=8., type=float, metavar='L',
                    help='coef of inv_loss')
parser.add_argument('--cov-coef', default=8., type=float, metavar='L',
                    help='coef of cov_loss')
parser.add_argument('--projector', default='8192-8192-8192', type=str,
                    metavar='MLP', help='projector MLP')
parser.add_argument('--num-classes', default=1000, type=int, metavar='N',
                    help='number of classes')
parser.add_argument('--identity-pred', action='store_true',
                    help='Predictor is fixed to an identity function.')
# store_true so that passing the flag DISABLES standardization, matching its
# name: Model applies the affine-free BatchNorm only when this is False.
parser.add_argument('--standardization-off', action='store_true',
                    help='Disable standardization (BatchNorm) before predictor.')
# main_worker reads args.use_amp to enable the GradScaler and autocast.
parser.add_argument('--use-amp', action='store_true',
                    help='train with automatic mixed precision (torch.cuda.amp)')
parser.add_argument('--print-freq', default=100, type=int, metavar='N',
                    help='print frequency')
parser.add_argument('--checkpoint-dir', default='./checkpoint/', type=Path,
                    metavar='DIR', help='path to checkpoint directory')
def main():
    """Parse the command line and spawn one training process per local GPU.

    Training is single-node: ranks run from 0 to ``device_count() - 1`` and
    the workers rendezvous over TCP on localhost at a randomly chosen port.

    Raises:
        RuntimeError: if no CUDA device is visible. ``main_worker`` uses the
            NCCL backend and per-GPU devices, so it cannot run without one.
    """
    args = parser.parse_args()
    args.ngpus_per_node = torch.cuda.device_count()
    if args.ngpus_per_node == 0:
        # spawn(nprocs=0) would start no workers and the script would exit
        # without training and without any error.
        raise RuntimeError('No CUDA device found; training requires at least one GPU.')
    # single-node distributed training
    args.rank = 0
    args.dist_url = f'tcp://127.0.0.1:{random.randint(10000, 60000)}'
    args.world_size = args.ngpus_per_node
    torch.multiprocessing.spawn(main_worker, (args,), args.ngpus_per_node)
def main_worker(gpu, args):
    """Run training in one process of the single-node distributed job.

    ``main`` starts this function once per GPU through
    ``torch.multiprocessing.spawn``, which passes the local process index as
    ``gpu``. Every process trains on its shard of the data. Only rank 0
    creates the checkpoint directory, logs to stdout and ``stats.txt``, and
    writes checkpoints and the final backbone weights.

    Args:
        gpu: local GPU index; added to ``args.rank`` to form this process's rank.
        args: parsed command-line namespace, extended by ``main`` with
            ``rank``, ``dist_url`` and ``world_size``.

    Raises:
        ValueError: if ``args.batch_size`` is not divisible by ``args.world_size``.
    """
    args.rank += gpu
    torch.distributed.init_process_group(
        backend='nccl', init_method=args.dist_url,
        world_size=args.world_size, rank=args.rank)

    stats_file = None
    if args.rank == 0:
        args.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        # Line-buffered so every stats line is written out immediately.
        stats_file = open(args.checkpoint_dir / 'stats.txt', 'a', buffering=1)
        print(' '.join(sys.argv))
        print(' '.join(sys.argv), file=stats_file)

    torch.cuda.set_device(gpu)
    torch.backends.cudnn.benchmark = True

    model = Model(args).cuda(gpu)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    # Two optimizer param groups: [0] everything except the predictor and
    # [1] the predictor. adjust_learning_rate sets the lr of both.
    param_enc = []
    param_predictor = []
    for name, param in model.named_parameters():
        if 'predictor' in name:
            param_predictor.append(param)
        else:
            param_enc.append(param)
    parameters = [{'params': param_enc}, {'params': param_predictor}]
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
    optimizer = LARS(parameters, lr=0., weight_decay=args.weight_decay,
                     weight_decay_filter=True,
                     lars_adaptation_filter=True)
    # If the parser defines no --use-amp option, fall back to full precision
    # instead of failing with AttributeError.
    use_amp = getattr(args, 'use_amp', False)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # automatically resume from checkpoint if it exists
    ckpt_path = args.checkpoint_dir / 'checkpoint.pth'
    if ckpt_path.is_file():
        ckpt = torch.load(ckpt_path, map_location='cpu')
        start_epoch = ckpt['epoch']
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        scaler.load_state_dict(ckpt['scaler'])
    else:
        start_epoch = 0

    dataset = torchvision.datasets.ImageFolder(args.data / 'train', Transform())
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    if args.batch_size % args.world_size != 0:
        raise ValueError('--batch-size must be divisible by the number of GPUs '
                         f'({args.batch_size} % {args.world_size} != 0)')
    per_device_batch_size = args.batch_size // args.world_size
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=per_device_batch_size, num_workers=args.workers,
        pin_memory=True, sampler=sampler, drop_last=True)

    start_time = time.time()
    loss = None  # stays None if the loader yields no batches
    for epoch in range(start_epoch, args.epochs):
        sampler.set_epoch(epoch)
        # `step` is the global step index, continuing across epochs.
        for step, ((y1, y2), labels) in enumerate(loader, start=epoch * len(loader)):
            y1 = y1.cuda(gpu, non_blocking=True)
            y2 = y2.cuda(gpu, non_blocking=True)
            labels = labels.cuda(gpu, non_blocking=True)
            adjust_learning_rate(args, optimizer, loader, step)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=use_amp):
                # Call the DDP wrapper itself rather than `.forward` so module
                # hooks run as usual.
                inv_loss, cov_loss, lincls_loss, lincls_pred = model(y1, y2, labels)
                loss = args.inv_coef * inv_loss + args.cov_coef * cov_loss + lincls_loss
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            if step % args.print_freq == 0 and args.rank == 0:
                acc1, acc5 = accuracy(output=lincls_pred, target=labels, topk=(1, 5))
                # 'lr_weights' / 'lr_biases' hold the lr of param group 0
                # (non-predictor) and group 1 (predictor). The key names are
                # kept so existing logs stay comparable.
                stats = dict(epoch=epoch, step=step,
                             lr_weights=optimizer.param_groups[0]['lr'],
                             lr_biases=optimizer.param_groups[1]['lr'],
                             inv_loss=inv_loss.item(),
                             cov_loss=cov_loss.item(),
                             lincls_loss=lincls_loss.item(),
                             acc1=acc1.item(),
                             acc5=acc5.item(),
                             time=int(time.time() - start_time))
                print(json.dumps(stats))
                print(json.dumps(stats), file=stats_file)
        # Check for divergence BEFORE saving, so a NaN model never overwrites
        # the last good checkpoint.
        if loss is not None and torch.isnan(loss):
            break
        if args.rank == 0:
            # save checkpoint
            state = dict(epoch=epoch + 1, model=model.state_dict(),
                         optimizer=optimizer.state_dict(), scaler=scaler.state_dict())
            # Keep a snapshot every 20 epochs (this also covers multiples of 100).
            if (epoch + 1) % 20 == 0:
                torch.save(state, args.checkpoint_dir / f'checkpoint_epoch_{epoch + 1}.pth')
            torch.save(state, args.checkpoint_dir / 'checkpoint.pth')
        # --stop-at-epoch counts completed epochs (1-based, like --epochs).
        if epoch + 1 >= args.stop_at_epoch:
            break

    if args.rank == 0:
        # save final model
        torch.save(model.module.backbone.state_dict(),
                   args.checkpoint_dir / 'resnet50.pth')
        stats_file.close()
    torch.distributed.destroy_process_group()
def adjust_learning_rate(args, optimizer, loader, step):
    """Set the learning rate of optimizer param groups 0 and 1 for ``step``.

    Schedule, measured in optimizer steps (``len(loader)`` steps per epoch):
      * first 10 epochs: linear warm-up from 0 to ``args.batch_size / 256``;
      * afterwards: cosine decay from that peak toward 0, reaching 0 at step
        ``args.epochs * len(loader)``.
    The scheduled value is multiplied by ``args.learning_rate_weights`` and
    written to both param groups.

    Args:
        args: namespace with ``epochs``, ``batch_size`` and
            ``learning_rate_weights``.
        optimizer: optimizer with at least two param groups.
        loader: training loader; only ``len(loader)`` is used.
        step: global step index counted from the start of training.
    """
    steps_per_epoch = len(loader)
    warmup_steps = 10 * steps_per_epoch
    peak_lr = args.batch_size / 256

    if step >= warmup_steps:
        # Cosine annealing over the post-warm-up part of training.
        decay_steps = args.epochs * steps_per_epoch - warmup_steps
        progress = step - warmup_steps
        q = 0.5 * (1 + math.cos(math.pi * progress / decay_steps))
        final_lr = 0.
        lr = peak_lr * q + final_lr * (1 - q)
    else:
        # Linear warm-up from 0.
        lr = peak_lr * step / warmup_steps

    scaled_lr = lr * args.learning_rate_weights
    optimizer.param_groups[0]['lr'] = scaled_lr
    optimizer.param_groups[1]['lr'] = scaled_lr
class Model(nn.Module):
    """Two-view self-supervised model with a linear monitoring classifier.

    Submodules:
      * ``backbone``: torchvision ResNet-50 (``zero_init_residual=True``) whose
        ``fc`` is replaced by ``nn.Identity``. The projector and ``lincls``
        below take its output width to be 2048.
      * ``projector``: MLP described by ``args.projector`` (dash-separated
        widths, e.g. '8192-8192-8192'). Hidden layers are
        Linear(bias=False) -> BatchNorm1d -> ReLU; the last layer is a plain
        Linear(bias=False).
      * ``predictor``: ``nn.Identity`` when ``args.identity_pred`` is set,
        otherwise a square ``nn.Linear`` on the projector output.
      * ``lincls``: linear classifier on detached backbone features, used only
        to monitor representation quality.
      * ``bn``: affine-free ``BatchNorm1d`` applied to both projections, or
        ``nn.Identity`` when ``args.standardization_off`` is truthy.

    NOTE(review): the construction order below probably determines random
    weight initialization and the order of ``named_parameters()``. main_worker
    uses that order to build the optimizer param groups, so reordering may
    break resuming from existing checkpoints. Keep the order as is.
    """
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.backbone = torchvision.models.resnet50(zero_init_residual=True)
        # Drop the classification head so the backbone returns features.
        self.backbone.fc = nn.Identity()
        # projector
        # sizes = [backbone width, *projector widths]
        sizes = [2048] + list(map(int, args.projector.split('-')))
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
            layers.append(nn.BatchNorm1d(sizes[i + 1]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        self.projector = nn.Sequential(*layers)
        # predictor
        # dim equals sizes[-1], the projector output width.
        dim = int(args.projector.split('-')[-1])
        if args.identity_pred:
            self.predictor = nn.Identity()
            # forward() reads self.predictor.weight, so a fixed identity matrix
            # is attached here. NOTE(review): this is a plain tensor attribute
            # (not a parameter or buffer) placed on the current CUDA device by
            # .cuda(). It needs CUDA at construction time and is not moved by
            # later .to()/.cuda(gpu) calls — confirm this is intended.
            self.predictor.weight = torch.eye(dim).cuda()
        else:
            self.predictor = nn.Linear(dim, dim)
        # monitoring lincls
        self.lincls = nn.Linear(2048, args.num_classes)
        # normalization layer for the representations z1 and z2
        # (BatchNorm without learnable scale/shift unless standardization is off)
        if not args.standardization_off:
            self.bn = nn.BatchNorm1d(sizes[-1], affine=False)
        else:
            self.bn = nn.Identity()
        self.ce = nn.CrossEntropyLoss()
        self.mse_loss = nn.MSELoss(reduction='mean')
    def forward(self, y1, y2, labels):
        """Compute the training losses for a pair of augmented views.

        Args:
            y1, y2: two augmented views of the same image batch (presumably
                (B, 3, 224, 224) tensors produced by ``Transform`` — TODO confirm).
            labels: class indices for ``y1``, consumed by ``nn.CrossEntropyLoss``.

        Returns:
            Tuple ``(inv_loss, cov_loss, lincls_loss, lincls_pred)``:
              * ``inv_loss``: mean squared difference between each view's
                predictor output and the other view's (standardized)
                projection, summed over both directions.
              * ``cov_loss``: MSE between each view's ``pred.T @ pred / N``
                matrix and ``W.T @ W`` of the predictor weight, summed over
                both views and multiplied by ``W.size(0)``.
              * ``lincls_loss``: cross-entropy of the monitoring classifier.
              * ``lincls_pred``: the monitoring classifier's logits for ``y1``.
        """
        r1 = self.backbone(y1)
        z1 = self.projector(r1)
        z2 = self.projector(self.backbone(y2))
        # Detach so the monitoring classifier's loss cannot update the backbone.
        r1_detach = r1.detach()
        lincls_pred = self.lincls(r1_detach)
        lincls_loss = self.ce(lincls_pred, labels)
        # Standardize each view's projection separately (no-op when self.bn is
        # nn.Identity).
        z1 = self.bn(z1)
        z2 = self.bn(z2)
        pred_z1 = self.predictor(z1)
        pred_z2 = self.predictor(z2)
        # Inv loss
        # Symmetric: the prediction from each view vs. the other view's
        # projection. pow_ works in place on the freshly created difference.
        inv_loss = (pred_z1-z2).pow_(2).mean() + (pred_z2-z1).pow_(2).mean()
        # empirical auto-correlation matrix
        # (uncentered, averaged over the first/batch dimension: pred.T @ pred / N)
        corr_z1 = (pred_z1.T @ pred_z1).div_(len(pred_z1))
        corr_z2 = (pred_z2.T @ pred_z2).div_(len(pred_z2))
        # nn.Linear weight, or the fixed identity matrix when identity_pred is set.
        w = self.predictor.weight
        wtw = w.T @ w
        # Match each view's auto-correlation to W^T W. The mean-reduced MSE is
        # rescaled by w.size(0).
        cov_loss = (self.mse_loss(corr_z2, wtw) + self.mse_loss(corr_z1, wtw))*w.size(0)
        return inv_loss, cov_loss, lincls_loss, lincls_pred
def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: 2-D scores/logits, one row per sample and one column per class.
        target: ground-truth class index of every sample.
        topk: the k values to evaluate.

    Returns:
        A list with one single-element tensor per k, holding the percentage of
        samples whose target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        num_samples = target.size(0)
        # Best max(topk) class indices per sample, best first, as (k, batch).
        _, ranked = output.topk(max(topk), dim=1, largest=True, sorted=True)
        ranked = ranked.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))
        scale = 100.0 / num_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
            for k in topk
        ]
class LARS(optim.Optimizer):
    """SGD with momentum and LARS (layer-wise adaptive rate scaling).

    Each parameter's update is rescaled by the trust ratio
    ``eta * ||p|| / ||update||`` and then accumulated into a momentum buffer.
    Weight decay and/or the trust-ratio scaling can be skipped for 1-D
    parameters (see ``exclude_bias_and_norm``).

    Args:
        params: iterable of parameters or param-group dicts.
        lr: learning rate.
        weight_decay: coefficient of ``p`` added to the gradient.
        momentum: momentum factor of the update buffer.
        eta: trust coefficient of the LARS ratio.
        weight_decay_filter: if True, skip weight decay for 1-D parameters.
        lars_adaptation_filter: if True, skip LARS scaling for 1-D parameters.
    """

    def __init__(self, params, lr, weight_decay=0, momentum=0.9, eta=0.001,
                 weight_decay_filter=False, lars_adaptation_filter=False):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
                        eta=eta, weight_decay_filter=weight_decay_filter,
                        lars_adaptation_filter=lars_adaptation_filter)
        super().__init__(params, defaults)

    def exclude_bias_and_norm(self, p):
        """Return True for 1-D tensors (typically biases and norm-layer params)."""
        return p.ndim == 1

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        For every parameter that has a gradient:
          1. add ``weight_decay * p`` to the gradient (skipped for 1-D params
             when ``weight_decay_filter`` is set);
          2. scale the update by ``eta * ||p|| / ||update||`` (skipped for 1-D
             params when ``lars_adaptation_filter`` is set; the ratio falls
             back to 1 when either norm is zero);
          3. update the momentum buffer ``mu = momentum * mu + update`` and
             apply ``p -= lr * mu``.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss, following the ``torch.optim.Optimizer.step``
                convention.

        Returns:
            The closure's loss, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # The closure typically runs backward(), so re-enable autograd.
            with torch.enable_grad():
                loss = closure()

        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad
                if dp is None:
                    continue
                if not g['weight_decay_filter'] or not self.exclude_bias_and_norm(p):
                    dp = dp.add(p, alpha=g['weight_decay'])
                if not g['lars_adaptation_filter'] or not self.exclude_bias_and_norm(p):
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    # Trust ratio; torch.where avoids a host sync and picks 1
                    # when either norm is zero.
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0.,
                                                (g['eta'] * param_norm / update_norm), one),
                                    one)
                    dp = dp.mul(q)
                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)
                p.add_(mu, alpha=-g['lr'])
        return loss
class GaussianBlur:
    """Randomly blur a PIL image.

    With probability ``p`` the image is filtered by
    ``ImageFilter.GaussianBlur`` with a radius drawn uniformly from
    [0.1, 2.0). Otherwise the image is returned unchanged.
    """

    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        apply_blur = random.random() < self.p
        if not apply_blur:
            return img
        radius = 0.1 + random.random() * 1.9
        return img.filter(ImageFilter.GaussianBlur(radius))
class Solarization:
    """With probability ``p``, solarize a PIL image via ``ImageOps.solarize``.

    Otherwise the image is returned unchanged. One random draw is consumed
    per call either way.
    """

    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        should_solarize = random.random() < self.p
        return ImageOps.solarize(img) if should_solarize else img
class Transform:
    """Produce two differently augmented tensor views of one PIL image.

    Both views share random resized crop (224), horizontal flip, color jitter,
    random grayscale, tensor conversion and normalization. They differ only in
    the probabilities of blur and solarization:
      * first view: blur p=1.0, solarize p=0.0;
      * second view: blur p=0.1, solarize p=0.2.
    """

    @staticmethod
    def _make_pipeline(blur_p, solarize_p):
        # Shared augmentation chain; only the blur/solarize probabilities vary.
        return transforms.Compose([
            transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                        saturation=0.2, hue=0.1)],
                p=0.8,
            ),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=blur_p),
            Solarization(p=solarize_p),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __init__(self):
        self.transform = self._make_pipeline(blur_p=1.0, solarize_p=0.0)
        self.transform_prime = self._make_pipeline(blur_p=0.1, solarize_p=0.2)

    def __call__(self, x):
        # The first view is drawn before the second, as callers expect (y1, y2).
        first_view = self.transform(x)
        second_view = self.transform_prime(x)
        return first_view, second_view
if __name__ == '__main__':
    # Start training only when run as a script, not when imported.
    main()